1use crate::stream::{BoxStream, Flow, NotUsed, Sink, Source, StreamCompletion};
2use crate::{StreamError, StreamResult};
3use futures::{FutureExt, channel::oneshot};
4use std::future::Future;
5use std::net::SocketAddr;
6use std::panic::AssertUnwindSafe;
7use std::path::PathBuf;
8use std::sync::{
9 Arc, Mutex,
10 atomic::{AtomicBool, Ordering},
11 mpsc as std_mpsc,
12};
13use std::thread::{self, Thread};
14use std::time::Duration;
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
17use tokio::sync::{mpsc, watch};
18
/// Chunk size used by the `from_path_default` constructors.
const DEFAULT_CHUNK_SIZE: usize = 8192;
/// Capacity (in chunks) of the channel between a file reader task and its consumer.
const FILE_READ_AHEAD_CHUNKS: usize = 8;
/// Size of each underlying file read buffer; raised to the chunk size when that is larger.
const FILE_INTERNAL_READ_SIZE: usize = 256 * 1024;
/// Presumably the read-ahead bound for TCP sources (not referenced in this part of the
/// file) — TODO confirm.
const TCP_READ_AHEAD_CHUNKS: usize = 1;
/// Longest single wait (park or `recv_timeout`) before a blocked consumer re-checks
/// cancellation and channel state.
const PARK_INTERVAL: Duration = Duration::from_millis(1);
/// Number of `yield_now` rounds `read_wait` performs before it starts parking.
const READ_READY_SPINS: usize = 256;
/// Number of `yield_now` rounds `backpressure_wait` performs before it starts parking.
const BACKPRESSURE_READY_SPINS: usize = 64;
/// Park duration used by `backpressure_wait` once its spin budget is exhausted.
const BACKPRESSURE_PARK: Duration = Duration::from_micros(10);
27
/// Handle that an async producer uses to wake the synchronous consumer thread, which
/// may be parked in `read_wait` waiting for the next item.
///
/// Clones share one slot, so the producer and the consumer see the same registration.
#[derive(Clone)]
struct ConsumerWaker {
    /// The most recently registered consumer thread, if any.
    thread: Arc<Mutex<Option<Thread>>>,
}

impl ConsumerWaker {
    /// Creates a waker with no consumer thread registered yet.
    fn new() -> Self {
        Self {
            thread: Arc::new(Mutex::new(None)),
        }
    }

    /// Registers the calling thread as the consumer to wake.
    ///
    /// The registration is refreshed whenever the caller is a different thread from the
    /// one recorded earlier. A consumer that moves between threads therefore keeps
    /// receiving prompt wake-ups, instead of the producer unparking a stale thread while
    /// the real consumer sleeps out its park timeout.
    fn capture_current(&self) {
        let current = thread::current();
        let mut slot = self.thread.lock().expect("consumer waker poisoned");
        let already_registered = slot.as_ref().is_some_and(|t| t.id() == current.id());
        if !already_registered {
            *slot = Some(current);
        }
    }

    /// Unparks the registered consumer thread. Does nothing if no thread is registered.
    fn unpark(&self) {
        let slot = self.thread.lock().expect("consumer waker poisoned");
        if let Some(t) = slot.as_ref() {
            t.unpark();
        }
    }
}
61
62fn io_error(error: std::io::Error) -> StreamError {
63 StreamError::Failed(error.to_string())
64}
65
66fn write_zero_error() -> StreamError {
67 StreamError::Failed("async writer returned zero bytes".to_owned())
68}
69
70#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct IoResult {
78 pub bytes: u64,
79 pub status: StreamResult<()>,
80}
81
82impl IoResult {
83 #[must_use]
85 pub fn succeeded(bytes: u64) -> Self {
86 Self {
87 bytes,
88 status: Ok(()),
89 }
90 }
91
92 #[must_use]
94 pub fn failed(bytes: u64, error: StreamError) -> Self {
95 Self {
96 bytes,
97 status: Err(error),
98 }
99 }
100
101 #[must_use]
103 pub fn bytes(&self) -> u64 {
104 self.bytes
105 }
106
107 #[must_use = "the terminal status should be inspected"]
109 pub fn status(&self) -> StreamResult<()> {
110 self.status.clone()
111 }
112
113 #[must_use]
115 pub fn is_success(&self) -> bool {
116 self.status.is_ok()
117 }
118}
119
/// Source of byte chunks whose `StreamCompletion` resolves to the final [`IoResult`].
pub type TokioByteSource = Source<Vec<u8>, StreamCompletion<IoResult>>;
/// Sink of byte chunks whose `StreamCompletion` resolves to the final [`IoResult`].
pub type TokioByteSink = Sink<Vec<u8>, StreamCompletion<IoResult>>;
124
/// Terminal outcome recorded by a producer task. The consumer reads it when the item
/// channel disconnects without delivering a final response.
#[derive(Clone)]
enum DemandTerminal {
    /// The producer reached the end of its input.
    Complete,
    /// The producer stopped with this error.
    Error(StreamError),
}

/// Message delivered from a producer task to the synchronous consumer.
enum DemandResponse<T> {
    /// One element of the stream.
    Item(T),
    /// End of stream.
    Complete,
    /// The stream failed with this error.
    Error(StreamError),
}
136
137struct DemandSourceStream<T> {
138 demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<T>>>,
139 cancel: watch::Sender<bool>,
140 terminal: Arc<Mutex<Option<DemandTerminal>>>,
141 done: bool,
142}
143
144impl<T> DemandSourceStream<T> {
145 fn terminal_response(&self) -> Option<Option<StreamResult<T>>> {
146 self.terminal
147 .lock()
148 .expect("tokio source terminal poisoned")
149 .clone()
150 .map(|terminal| match terminal {
151 DemandTerminal::Complete => None,
152 DemandTerminal::Error(error) => Some(Err(error)),
153 })
154 }
155
156 fn mark_done(&mut self) {
157 self.done = true;
158 let _ = self.cancel.send(true);
159 }
160}
161
162impl<T: Send + 'static> Iterator for DemandSourceStream<T> {
163 type Item = StreamResult<T>;
164
165 fn next(&mut self) -> Option<Self::Item> {
166 if self.done {
167 return None;
168 }
169
170 let stream_cancelled = crate::stream::current_stream_cancelled();
171 let (reply_sender, reply_receiver) = std_mpsc::channel();
172 if !send_bounded_demand(&self.demands, reply_sender, &stream_cancelled) {
173 self.mark_done();
174 return self
175 .terminal_response()
176 .unwrap_or(Some(Err(StreamError::AbruptTermination)));
177 }
178
179 loop {
180 if stream_cancelled
181 .as_ref()
182 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
183 {
184 self.mark_done();
185 return Some(Err(StreamError::Cancelled));
186 }
187
188 match reply_receiver.recv_timeout(PARK_INTERVAL) {
189 Ok(DemandResponse::Item(item)) => return Some(Ok(item)),
190 Ok(DemandResponse::Complete) => {
191 self.mark_done();
192 return None;
193 }
194 Ok(DemandResponse::Error(error)) => {
195 self.mark_done();
196 return Some(Err(error));
197 }
198 Err(std_mpsc::RecvTimeoutError::Timeout) => {}
199 Err(std_mpsc::RecvTimeoutError::Disconnected) => {
200 self.mark_done();
201 return self
202 .terminal_response()
203 .unwrap_or(Some(Err(StreamError::AbruptTermination)));
204 }
205 }
206 }
207 }
208}
209
210impl<T> Drop for DemandSourceStream<T> {
211 fn drop(&mut self) {
212 let _ = self.cancel.send(true);
213 }
214}
215
216struct BoundedByteSourceStream {
217 receiver: mpsc::Receiver<DemandResponse<Vec<u8>>>,
218 cancel: watch::Sender<bool>,
219 terminal: Arc<Mutex<Option<DemandTerminal>>>,
220 done: bool,
221 waker: ConsumerWaker,
222}
223
224impl BoundedByteSourceStream {
225 fn terminal_response(&self) -> Option<Option<StreamResult<Vec<u8>>>> {
226 self.terminal
227 .lock()
228 .expect("tokio source terminal poisoned")
229 .clone()
230 .map(|terminal| match terminal {
231 DemandTerminal::Complete => None,
232 DemandTerminal::Error(error) => Some(Err(error)),
233 })
234 }
235
236 fn mark_done(&mut self) {
237 self.done = true;
238 let _ = self.cancel.send(true);
239 }
240}
241
242impl Iterator for BoundedByteSourceStream {
243 type Item = StreamResult<Vec<u8>>;
244
245 fn next(&mut self) -> Option<Self::Item> {
246 if self.done {
247 return None;
248 }
249
250 self.waker.capture_current();
254
255 let stream_cancelled = crate::stream::current_stream_cancelled();
256 let mut spins = 0usize;
257 loop {
258 if stream_cancelled
259 .as_ref()
260 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
261 {
262 self.mark_done();
263 return Some(Err(StreamError::Cancelled));
264 }
265
266 match self.receiver.try_recv() {
267 Ok(DemandResponse::Item(item)) => return Some(Ok(item)),
268 Ok(DemandResponse::Complete) => {
269 self.mark_done();
270 return None;
271 }
272 Ok(DemandResponse::Error(error)) => {
273 self.mark_done();
274 return Some(Err(error));
275 }
276 Err(mpsc::error::TryRecvError::Empty) => read_wait(&mut spins),
277 Err(mpsc::error::TryRecvError::Disconnected) => {
278 self.mark_done();
279 return self
280 .terminal_response()
281 .unwrap_or(Some(Err(StreamError::AbruptTermination)));
282 }
283 }
284 }
285 }
286}
287
288impl Drop for BoundedByteSourceStream {
289 fn drop(&mut self) {
290 let _ = self.cancel.send(true);
291 }
292}
293
294fn send_bounded_demand<T>(
295 sender: &mpsc::Sender<T>,
296 mut message: T,
297 stream_cancelled: &Option<Arc<AtomicBool>>,
298) -> bool {
299 let mut spins = 0usize;
300 loop {
301 if stream_cancelled
302 .as_ref()
303 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
304 {
305 return false;
306 }
307
308 match sender.try_send(message) {
309 Ok(()) => return true,
310 Err(mpsc::error::TrySendError::Full(returned)) => {
311 message = returned;
312 backpressure_wait(&mut spins);
313 }
314 Err(mpsc::error::TrySendError::Closed(_)) => return false,
315 }
316 }
317}
318
319fn finish_terminal(terminal: &Arc<Mutex<Option<DemandTerminal>>>, value: DemandTerminal) {
320 let mut slot = terminal.lock().expect("tokio source terminal poisoned");
321 if slot.is_none() {
322 *slot = Some(value);
323 }
324}
325
326async fn next_demand<T>(
327 demands: &mut mpsc::Receiver<std_mpsc::Sender<DemandResponse<T>>>,
328 cancel: &mut watch::Receiver<bool>,
329) -> Option<std_mpsc::Sender<DemandResponse<T>>> {
330 if *cancel.borrow() {
331 return None;
332 }
333
334 tokio::select! {
335 demand = demands.recv() => demand,
336 changed = cancel.changed() => {
337 let _ = changed;
338 None
339 }
340 }
341}
342
343fn async_read_source<R, Fut>(
344 open: impl FnOnce() -> Fut + Send + 'static,
345 chunk_size: usize,
346 internal_read_size: usize,
347 read_ahead_chunks: usize,
348) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>)
349where
350 R: AsyncRead + Unpin + Send + 'static,
351 Fut: Future<Output = std::io::Result<R>> + Send + 'static,
352{
353 assert!(chunk_size > 0, "chunk size must be greater than zero");
354 assert!(
355 read_ahead_chunks > 0,
356 "read-ahead bound must be greater than zero"
357 );
358 let internal_read_size = internal_read_size.max(chunk_size);
359 let (item_sender, item_receiver) = mpsc::channel(read_ahead_chunks);
360 let (cancel_sender, cancel_receiver) = watch::channel(false);
361 let (mat_sender, mat_receiver) = oneshot::channel();
362 let terminal = Arc::new(Mutex::new(None));
363 let terminal_for_task = Arc::clone(&terminal);
364 let waker = ConsumerWaker::new();
365 let producer_waker = waker.clone();
366
367 crate::stream::stream_tokio_runtime().spawn(async move {
368 let result = AssertUnwindSafe(run_async_read_task(
369 open(),
370 chunk_size,
371 internal_read_size,
372 item_sender,
373 cancel_receiver,
374 Arc::clone(&terminal_for_task),
375 producer_waker,
376 ))
377 .catch_unwind()
378 .await
379 .unwrap_or_else(|_| {
380 finish_terminal(
381 &terminal_for_task,
382 DemandTerminal::Error(StreamError::AbruptTermination),
383 );
384 Err(StreamError::AbruptTermination)
385 });
386 let _ = mat_sender.send(result);
387 });
388
389 (
390 Box::new(BoundedByteSourceStream {
391 receiver: item_receiver,
392 cancel: cancel_sender,
393 terminal,
394 done: false,
395 waker,
396 }) as BoxStream<Vec<u8>>,
397 StreamCompletion::from_receiver(mat_receiver, None),
398 )
399}
400
/// Producer task behind `async_read_source`. It opens the reader, reads into a buffer
/// of `internal_read_size` bytes, and forwards the data on `items` as `chunk_size`
/// items; the final item may be shorter.
///
/// Every exit path calls `finish_terminal` before (best-effort) sending the matching
/// `Complete`/`Error` response. A consumer that only sees the channel disconnect can
/// therefore still report the outcome. A change on `cancel` at any await point ends the
/// task with `StreamError::Cancelled`.
///
/// Always returns `Ok`; failures are carried inside the [`IoResult`]. `bytes` counts
/// bytes read from the reader, including any not yet delivered when the task stops.
async fn run_async_read_task<R, Fut>(
    open: Fut,
    chunk_size: usize,
    internal_read_size: usize,
    items: mpsc::Sender<DemandResponse<Vec<u8>>>,
    mut cancel: watch::Receiver<bool>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
    waker: ConsumerWaker,
) -> StreamResult<IoResult>
where
    R: AsyncRead + Unpin + Send + 'static,
    Fut: Future<Output = std::io::Result<R>> + Send + 'static,
{
    let mut bytes = 0_u64;
    // Race the open against cancellation so a pending open can be abandoned.
    let mut reader = tokio::select! {
        reader = open => match reader {
            Ok(reader) => reader,
            Err(error) => {
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(&items, DemandResponse::Error(error.clone()), &mut cancel, &waker).await;
                return Ok(IoResult::failed(bytes, error));
            }
        },
        changed = cancel.changed() => {
            let _ = changed;
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        }
    };

    let mut buffer = vec![0_u8; internal_read_size];
    // Bytes left over from earlier reads that do not yet fill a whole chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    loop {
        let read = tokio::select! {
            read = reader.read(&mut buffer) => read,
            changed = cancel.changed() => {
                let _ = changed;
                finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        };

        match read {
            // End of input: flush any partial chunk, then record and signal completion.
            Ok(0) => {
                if !pending_tail.is_empty()
                    && !send_read_item(
                        &items,
                        DemandResponse::Item(std::mem::take(&mut pending_tail)),
                        &mut cancel,
                        &waker,
                    )
                    .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
                finish_terminal(&terminal, DemandTerminal::Complete);
                let _ = send_read_item(&items, DemandResponse::Complete, &mut cancel, &waker).await;
                return Ok(IoResult::succeeded(bytes));
            }
            Ok(read) => {
                bytes += read as u64;
                // A failed send means cancellation was signalled or the consumer's
                // receiver is gone; both are reported as cancellation.
                if !send_read_chunks(
                    &items,
                    chunk_size,
                    &mut pending_tail,
                    &buffer[..read],
                    &mut cancel,
                    &waker,
                )
                .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
            }
            Err(error) => {
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(
                    &items,
                    DemandResponse::Error(error.clone()),
                    &mut cancel,
                    &waker,
                )
                .await;
                return Ok(IoResult::failed(bytes, error));
            }
        }
    }
}
493
494async fn send_read_chunks(
495 sender: &mpsc::Sender<DemandResponse<Vec<u8>>>,
496 chunk_size: usize,
497 pending_tail: &mut Vec<u8>,
498 read_buffer: &[u8],
499 cancel: &mut watch::Receiver<bool>,
500 waker: &ConsumerWaker,
501) -> bool {
502 let mut offset = 0;
503 if !pending_tail.is_empty() {
504 let needed = chunk_size - pending_tail.len();
505 let take = needed.min(read_buffer.len());
506 pending_tail.extend_from_slice(&read_buffer[..take]);
507 offset += take;
508 if pending_tail.len() == chunk_size
509 && !send_read_item(
510 sender,
511 DemandResponse::Item(std::mem::take(pending_tail)),
512 cancel,
513 waker,
514 )
515 .await
516 {
517 return false;
518 }
519 }
520
521 while offset + chunk_size <= read_buffer.len() {
522 let next = offset + chunk_size;
523 if !send_read_item(
524 sender,
525 DemandResponse::Item(read_buffer[offset..next].to_vec()),
526 cancel,
527 waker,
528 )
529 .await
530 {
531 return false;
532 }
533 offset = next;
534 }
535
536 if offset < read_buffer.len() {
537 pending_tail.extend_from_slice(&read_buffer[offset..]);
538 }
539 true
540}
541
542async fn send_read_item<T>(
543 sender: &mpsc::Sender<DemandResponse<T>>,
544 item: DemandResponse<T>,
545 cancel: &mut watch::Receiver<bool>,
546 waker: &ConsumerWaker,
547) -> bool
548where
549 T: Send + 'static,
550{
551 let result = tokio::select! {
552 result = sender.send(item) => result,
553 changed = cancel.changed() => {
554 let _ = changed;
555 return false;
556 }
557 };
558 if result.is_ok() {
559 waker.unpark();
560 }
561 result.is_ok()
562}
563
/// Message from the synchronous feeder to an async writer task.
enum WriteCommand {
    /// Bytes to write.
    Chunk(Vec<u8>),
    /// Upstream has ended. The writer finalizes the output and reports this upstream
    /// status combined with its own finalization result.
    Finish(StreamResult<()>),
}
568
569struct TokioCancelGuard {
570 cancel: watch::Sender<bool>,
571 armed: bool,
572}
573
574impl TokioCancelGuard {
575 fn new(cancel: watch::Sender<bool>) -> Self {
576 Self {
577 cancel,
578 armed: true,
579 }
580 }
581
582 fn disarm(&mut self) {
583 self.armed = false;
584 }
585}
586
587impl Drop for TokioCancelGuard {
588 fn drop(&mut self) {
589 if self.armed {
590 let _ = self.cancel.send(true);
591 }
592 }
593}
594
595fn async_write_sink<W, F, Fut>(open: F) -> TokioByteSink
596where
597 W: AsyncWrite + Unpin + Send + 'static,
598 F: Fn() -> Fut + Send + Sync + 'static,
599 Fut: Future<Output = std::io::Result<W>> + Send + 'static,
600{
601 let open = Arc::new(open);
602 Sink::from_runner(move |input, materializer| {
603 run_async_write_sink::<W, F, Fut>(input, materializer, Arc::clone(&open))
604 })
605}
606
607async fn run_async_write_task<W, Fut>(
608 open: Fut,
609 mut commands: mpsc::Receiver<WriteCommand>,
610 mut cancel: watch::Receiver<bool>,
611) -> StreamResult<IoResult>
612where
613 W: AsyncWrite + Unpin + Send + 'static,
614 Fut: Future<Output = std::io::Result<W>> + Send + 'static,
615{
616 let mut bytes = 0_u64;
617 let mut writer = tokio::select! {
618 writer = open => match writer {
619 Ok(writer) => writer,
620 Err(error) => return Ok(IoResult::failed(bytes, io_error(error))),
621 },
622 changed = cancel.changed() => {
623 let _ = changed;
624 return Ok(IoResult::failed(bytes, StreamError::Cancelled));
625 }
626 };
627
628 loop {
629 let command = tokio::select! {
630 command = commands.recv() => command,
631 changed = cancel.changed() => {
632 let _ = changed;
633 return Ok(IoResult::failed(bytes, StreamError::Cancelled));
634 }
635 };
636
637 match command {
638 Some(WriteCommand::Chunk(chunk)) => {
639 if let Err(error) = write_chunk(&mut writer, &chunk, &mut cancel, &mut bytes).await
640 {
641 return Ok(IoResult::failed(bytes, error));
642 }
643 }
644 Some(WriteCommand::Finish(upstream_status)) => {
645 let shutdown_status = shutdown_writer(&mut writer, &mut cancel).await;
646 return Ok(IoResult {
647 bytes,
648 status: upstream_status.and(shutdown_status),
649 });
650 }
651 None => {
652 let _ = shutdown_writer(&mut writer, &mut cancel).await;
653 return Ok(IoResult::failed(bytes, StreamError::Cancelled));
654 }
655 }
656 }
657}
658
659async fn write_chunk<W>(
660 writer: &mut W,
661 chunk: &[u8],
662 cancel: &mut watch::Receiver<bool>,
663 bytes: &mut u64,
664) -> StreamResult<()>
665where
666 W: AsyncWrite + Unpin,
667{
668 let mut offset = 0usize;
669 while offset < chunk.len() {
670 let written = tokio::select! {
671 written = writer.write(&chunk[offset..]) => written.map_err(io_error)?,
672 changed = cancel.changed() => {
673 let _ = changed;
674 return Err(StreamError::Cancelled);
675 }
676 };
677
678 if written == 0 {
679 return Err(write_zero_error());
680 }
681 offset += written;
682 *bytes += written as u64;
683 }
684 Ok(())
685}
686
687async fn shutdown_writer<W>(writer: &mut W, cancel: &mut watch::Receiver<bool>) -> StreamResult<()>
688where
689 W: AsyncWrite + Unpin,
690{
691 tokio::select! {
692 result = writer.flush() => result.map_err(io_error)?,
693 changed = cancel.changed() => {
694 let _ = changed;
695 return Err(StreamError::Cancelled);
696 }
697 }
698
699 tokio::select! {
700 result = writer.shutdown() => result.map_err(io_error),
701 changed = cancel.changed() => {
702 let _ = changed;
703 Err(StreamError::Cancelled)
704 }
705 }
706}
707
/// Synchronous half of an async write sink. It pulls chunks from `input`, forwards them
/// to the writer task as `WriteCommand`s, then blocks until the task reports.
///
/// - When upstream ends or fails, that status (`Ok(())` or the upstream error) is
///   forwarded as `WriteCommand::Finish`.
/// - If `cancelled` is set, no `Finish` is sent. Instead `cancel_sender` signals the
///   task, and the signal is re-sent on each `PARK_INTERVAL` timeout while waiting.
/// - If a chunk cannot be delivered (the task has closed its receiver), feeding stops
///   and the task's own result is returned.
///
/// Returns the writer task's result, or `StreamError::AbruptTermination` if the task
/// vanished without reporting.
fn feed_async_writer(
    mut input: BoxStream<Vec<u8>>,
    command_sender: mpsc::Sender<WriteCommand>,
    done_receiver: std_mpsc::Receiver<StreamResult<IoResult>>,
    cancelled: Arc<AtomicBool>,
    cancel_sender: watch::Sender<bool>,
) -> StreamResult<IoResult> {
    let mut terminal = Ok(());
    loop {
        if cancelled.load(Ordering::SeqCst) {
            terminal = Err(StreamError::Cancelled);
            break;
        }

        match input.next() {
            Some(Ok(chunk)) => {
                // `false` means cancellation or a closed writer; stop feeding either way.
                if !send_write_command(&command_sender, WriteCommand::Chunk(chunk), &cancelled) {
                    break;
                }
            }
            Some(Err(error)) => {
                terminal = Err(error);
                break;
            }
            None => break,
        }
    }

    if cancelled.load(Ordering::SeqCst) {
        let _ = cancel_sender.send(true);
    } else {
        let _ = send_write_command(&command_sender, WriteCommand::Finish(terminal), &cancelled);
    }
    // Closing the channel lets the task observe `None` if `Finish` was never delivered.
    drop(command_sender);

    loop {
        match done_receiver.recv_timeout(PARK_INTERVAL) {
            Ok(result) => return result,
            Err(std_mpsc::RecvTimeoutError::Timeout) => {
                // Cancellation may arrive while waiting; keep relaying it to the task.
                if cancelled.load(Ordering::SeqCst) {
                    let _ = cancel_sender.send(true);
                }
            }
            Err(std_mpsc::RecvTimeoutError::Disconnected) => {
                return Err(StreamError::AbruptTermination);
            }
        }
    }
}
760
761fn send_write_command(
762 sender: &mpsc::Sender<WriteCommand>,
763 mut command: WriteCommand,
764 cancelled: &AtomicBool,
765) -> bool {
766 let mut spins = 0usize;
767 loop {
768 if cancelled.load(Ordering::SeqCst) {
769 return false;
770 }
771
772 match sender.try_send(command) {
773 Ok(()) => return true,
774 Err(mpsc::error::TrySendError::Full(returned)) => {
775 command = returned;
776 backpressure_wait(&mut spins);
777 }
778 Err(mpsc::error::TrySendError::Closed(_)) => return false,
779 }
780 }
781}
782
783fn backpressure_wait(spins: &mut usize) {
784 if *spins < BACKPRESSURE_READY_SPINS {
785 *spins += 1;
786 thread::yield_now();
787 } else {
788 thread::park_timeout(BACKPRESSURE_PARK);
789 }
790}
791
792fn read_wait(spins: &mut usize) {
793 if *spins < READ_READY_SPINS {
794 *spins += 1;
795 thread::yield_now();
796 } else {
797 thread::park_timeout(PARK_INTERVAL);
798 }
799}
800
801fn tokio_file_read_source(
802 path: PathBuf,
803 chunk_size: usize,
804) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>) {
805 async_read_source(
806 move || tokio::fs::File::open(path),
807 chunk_size,
808 FILE_INTERNAL_READ_SIZE,
809 FILE_READ_AHEAD_CHUNKS,
810 )
811}
812
813async fn tokio_file_write_open(path: Arc<PathBuf>) -> std::io::Result<tokio::fs::File> {
814 tokio::fs::OpenOptions::new()
815 .create(true)
816 .truncate(true)
817 .write(true)
818 .open(path.as_ref())
819 .await
820}
821
822fn run_async_write_sink<W, F, Fut>(
823 input: BoxStream<Vec<u8>>,
824 materializer: &crate::stream::Materializer,
825 open: Arc<F>,
826) -> StreamResult<StreamCompletion<IoResult>>
827where
828 W: AsyncWrite + Unpin + Send + 'static,
829 F: Fn() -> Fut + Send + Sync + 'static,
830 Fut: Future<Output = std::io::Result<W>> + Send + 'static,
831{
832 let (command_sender, command_receiver) = mpsc::channel(1);
833 let (cancel_sender, cancel_receiver) = watch::channel(false);
834 let (done_sender, done_receiver) = std_mpsc::sync_channel(1);
835
836 crate::stream::stream_tokio_runtime().spawn(async move {
837 let result = AssertUnwindSafe(run_async_write_task(
838 open(),
839 command_receiver,
840 cancel_receiver,
841 ))
842 .catch_unwind()
843 .await
844 .unwrap_or(Err(StreamError::AbruptTermination));
845 let _ = done_sender.send(result);
846 });
847
848 Ok(materializer.spawn_stream(move |cancelled| {
849 let mut guard = TokioCancelGuard::new(cancel_sender.clone());
850 let result = feed_async_writer(
851 input,
852 command_sender,
853 done_receiver,
854 cancelled,
855 cancel_sender,
856 );
857 guard.disarm();
858 result
859 }))
860}
861
862fn tokio_file_write_sink(path: PathBuf) -> TokioByteSink {
863 let path = Arc::new(path);
864 async_write_sink(move || tokio_file_write_open(Arc::clone(&path)))
865}
866
867pub struct TokioFileIO;
868
869impl TokioFileIO {
870 #[must_use]
880 pub fn from_path(path: impl Into<PathBuf>, chunk_size: usize) -> TokioByteSource {
881 assert!(chunk_size > 0, "chunk size must be greater than zero");
882 let path = path.into();
883 Source::from_materialized_factory(move |_materializer| {
884 let path = path.clone();
885 Ok(tokio_file_read_source(path, chunk_size))
886 })
887 }
888
889 #[must_use]
891 pub fn from_path_default(path: impl Into<PathBuf>) -> TokioByteSource {
892 Self::from_path(path, DEFAULT_CHUNK_SIZE)
893 }
894
895 #[must_use]
902 pub fn to_path(path: impl Into<PathBuf>) -> TokioByteSink {
903 tokio_file_write_sink(path.into())
904 }
905}
906
/// File sources and sinks that use a dedicated tokio-uring runtime thread on Linux.
/// They fall back to the `tokio::fs` implementation when that runtime cannot accept the
/// job, and on non-Linux targets.
#[cfg(feature = "io-uring-file")]
pub struct UringFileIO;

#[cfg(feature = "io-uring-file")]
impl UringFileIO {
    /// Creates a source that streams the file at `path` in chunks of `chunk_size` bytes;
    /// the final chunk may be shorter.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn from_path(path: impl Into<PathBuf>, chunk_size: usize) -> TokioByteSource {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        let path: PathBuf = path.into();
        Source::from_materialized_factory(move |_materializer| {
            Ok(uring_file_read_source_or_fallback(path.clone(), chunk_size))
        })
    }

    /// Same as [`UringFileIO::from_path`] with `DEFAULT_CHUNK_SIZE` (8192-byte) chunks.
    #[must_use]
    pub fn from_path_default(path: impl Into<PathBuf>) -> TokioByteSource {
        Self::from_path(path, DEFAULT_CHUNK_SIZE)
    }

    /// Creates a sink that writes every received byte to `path`, creating or truncating
    /// the file.
    #[must_use]
    pub fn to_path(path: impl Into<PathBuf>) -> TokioByteSink {
        let path: PathBuf = path.into();
        uring_file_write_sink_or_fallback(path)
    }
}
953
/// On non-Linux targets the tokio-uring path is not compiled; the "uring" source is
/// simply the tokio file source.
#[cfg(all(feature = "io-uring-file", not(target_os = "linux")))]
fn uring_file_read_source_or_fallback(
    path: PathBuf,
    chunk_size: usize,
) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>) {
    tokio_file_read_source(path, chunk_size)
}

/// On non-Linux targets the "uring" sink is simply the tokio file sink.
#[cfg(all(feature = "io-uring-file", not(target_os = "linux")))]
fn uring_file_write_sink_or_fallback(path: PathBuf) -> TokioByteSink {
    tokio_file_write_sink(path)
}
966
/// Work item sent to the tokio-uring runtime thread. It is a `Send` closure that, when
/// run on that thread, builds the future to spawn there. The future itself carries no
/// `Send` bound.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
type UringJob =
    Box<dyn FnOnce() -> std::pin::Pin<Box<dyn Future<Output = ()> + 'static>> + Send + 'static>;

/// Cloneable handle for submitting [`UringJob`]s to the runtime thread.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
#[derive(Clone)]
struct UringRuntimeHandle {
    sender: mpsc::UnboundedSender<UringJob>,
}

/// Startup report sent from the runtime thread back to `start_uring_runtime`.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
enum UringRuntimeInit {
    /// The runtime was created and is about to start processing jobs.
    Ready,
    /// Creating the runtime failed with this error.
    Failed(std::io::Error),
}
982
/// Returns the process-wide tokio-uring runtime handle, starting the runtime thread on
/// first use. A startup failure is cached, and its message is returned on every call.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn uring_runtime_handle() -> Result<&'static UringRuntimeHandle, String> {
    static HANDLE: std::sync::OnceLock<Result<UringRuntimeHandle, String>> =
        std::sync::OnceLock::new();
    match HANDLE.get_or_init(start_uring_runtime) {
        Ok(handle) => Ok(handle),
        Err(message) => Err(message.clone()),
    }
}
992
/// Spawns the `datum-uring-file` thread, which builds a tokio-uring runtime and then
/// runs `run_uring_jobs` on it. Blocks until the thread reports whether the runtime was
/// created.
///
/// Returns the job-submission handle, or a message describing why the thread could not
/// be spawned or the runtime could not start.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn start_uring_runtime() -> Result<UringRuntimeHandle, String> {
    let (ready_tx, ready_rx) = std_mpsc::sync_channel(1);
    let (job_tx, job_rx) = mpsc::unbounded_channel::<UringJob>();

    let runtime_thread = move || {
        let runtime = match tokio_uring::Runtime::new(&tokio_uring::builder()) {
            Ok(runtime) => runtime,
            Err(error) => {
                let _ = ready_tx.send(UringRuntimeInit::Failed(error));
                return;
            }
        };
        let _ = ready_tx.send(UringRuntimeInit::Ready);
        runtime.block_on(run_uring_jobs(job_rx));
    };
    if let Err(error) = thread::Builder::new()
        .name("datum-uring-file".to_owned())
        .spawn(runtime_thread)
    {
        return Err(error.to_string());
    }

    match ready_rx.recv() {
        Ok(UringRuntimeInit::Ready) => Ok(UringRuntimeHandle { sender: job_tx }),
        Ok(UringRuntimeInit::Failed(error)) => Err(error.to_string()),
        Err(_) => Err("tokio-uring runtime thread exited".to_owned()),
    }
}
1018
/// Job loop run on the tokio-uring runtime thread. It spawns each received job's future
/// onto that runtime, and stops once all senders are dropped.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
async fn run_uring_jobs(mut receiver: mpsc::UnboundedReceiver<UringJob>) {
    loop {
        let Some(job) = receiver.recv().await else {
            break;
        };
        let future = job();
        tokio_uring::spawn(future);
    }
}
1025
/// Hands `job` to the tokio-uring runtime thread, starting that thread if needed.
///
/// Returns the runtime's startup error, or a message if the runtime thread no longer
/// receives jobs. In both cases `job` is dropped without running.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn spawn_uring_job(job: UringJob) -> Result<(), String> {
    let handle = uring_runtime_handle()?;
    match handle.sender.send(job) {
        Ok(()) => Ok(()),
        Err(_) => Err("tokio-uring runtime thread exited".to_owned()),
    }
}
1033
/// Linux io_uring file source. It runs `run_uring_read_task` on the shared tokio-uring
/// runtime thread, feeding a bounded channel of `FILE_READ_AHEAD_CHUNKS` items that the
/// returned stream consumes.
///
/// If the job cannot be submitted (the runtime failed to start or its thread exited),
/// the unsubmitted job and the channel ends it captured are dropped, and the
/// `tokio::fs`-based source is returned instead.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn uring_file_read_source_or_fallback(
    path: PathBuf,
    chunk_size: usize,
) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>) {
    // Each read fills a buffer holding at least one full chunk.
    let internal_read_size = FILE_INTERNAL_READ_SIZE.max(chunk_size);
    let (item_sender, item_receiver) = mpsc::channel(FILE_READ_AHEAD_CHUNKS);
    let (cancel_sender, cancel_receiver) = watch::channel(false);
    let (mat_sender, mat_receiver) = oneshot::channel();
    let terminal = Arc::new(Mutex::new(None));
    let terminal_for_task = Arc::clone(&terminal);
    let waker = ConsumerWaker::new();
    let producer_waker = waker.clone();
    // `path` itself is kept for the fallback branch below.
    let uring_path = path.clone();

    let result = spawn_uring_job(Box::new(move || {
        Box::pin(async move {
            let task_result = AssertUnwindSafe(run_uring_read_task(
                uring_path,
                chunk_size,
                internal_read_size,
                item_sender,
                cancel_receiver,
                Arc::clone(&terminal_for_task),
                producer_waker,
            ))
            .catch_unwind()
            .await
            .unwrap_or_else(|_| {
                // After a panic, still leave the consumer a terminal outcome to report.
                finish_terminal(
                    &terminal_for_task,
                    DemandTerminal::Error(StreamError::AbruptTermination),
                );
                Err(StreamError::AbruptTermination)
            });
            let _ = mat_sender.send(task_result);
        })
    }));

    if result.is_err() {
        return tokio_file_read_source(path, chunk_size);
    }

    (
        Box::new(BoundedByteSourceStream {
            receiver: item_receiver,
            cancel: cancel_sender,
            terminal,
            done: false,
            waker,
        }) as BoxStream<Vec<u8>>,
        StreamCompletion::from_receiver(mat_receiver, None),
    )
}
1088
/// io_uring counterpart of `run_async_read_task`. It opens `path` with tokio-uring,
/// issues positional `read_at` calls at a running offset, and forwards the data on
/// `items` as `chunk_size` items; the final item may be shorter.
///
/// Like the tokio version, every exit records a terminal outcome via `finish_terminal`
/// before best-effort signalling, and a cancel notification at any await point ends the
/// task with `StreamError::Cancelled`. Always returns `Ok`; failures are carried in the
/// [`IoResult`].
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
async fn run_uring_read_task(
    path: PathBuf,
    chunk_size: usize,
    internal_read_size: usize,
    items: mpsc::Sender<DemandResponse<Vec<u8>>>,
    mut cancel: watch::Receiver<bool>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
    waker: ConsumerWaker,
) -> StreamResult<IoResult> {
    let mut bytes = 0_u64;
    let file = tokio::select! {
        file = tokio_uring::fs::File::open(path) => match file {
            Ok(file) => file,
            Err(error) => {
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(&items, DemandResponse::Error(error.clone()), &mut cancel, &waker).await;
                return Ok(IoResult::failed(bytes, error));
            }
        },
        changed = cancel.changed() => {
            let _ = changed;
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        }
    };

    // File position for the next positional read.
    let mut offset = 0_u64;
    let mut buffer = Vec::with_capacity(internal_read_size);
    // Bytes left over from earlier reads that do not yet fill a whole chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    loop {
        // `read_at` takes ownership of the buffer and hands it back with the result.
        let read_buffer = std::mem::take(&mut buffer);
        let (read, returned_buffer) = tokio::select! {
            result = file.read_at(read_buffer, offset) => result,
            changed = cancel.changed() => {
                let _ = changed;
                finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        };
        buffer = returned_buffer;

        match read {
            // End of file: flush any partial chunk, then record and signal completion.
            Ok(0) => {
                if !pending_tail.is_empty()
                    && !send_read_item(
                        &items,
                        DemandResponse::Item(std::mem::take(&mut pending_tail)),
                        &mut cancel,
                        &waker,
                    )
                    .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
                finish_terminal(&terminal, DemandTerminal::Complete);
                let _ = send_read_item(&items, DemandResponse::Complete, &mut cancel, &waker).await;
                return Ok(IoResult::succeeded(bytes));
            }
            Ok(read) => {
                bytes += read as u64;
                offset += read as u64;
                if !send_read_chunks(
                    &items,
                    chunk_size,
                    &mut pending_tail,
                    &buffer[..read],
                    &mut cancel,
                    &waker,
                )
                .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
                // Reset the length; the allocation is reused for the next read.
                buffer.clear();
            }
            Err(error) => {
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(
                    &items,
                    DemandResponse::Error(error.clone()),
                    &mut cancel,
                    &waker,
                )
                .await;
                return Ok(IoResult::failed(bytes, error));
            }
        }
    }
}
1183
/// Linux io_uring file sink. On each run it submits `run_uring_write_task` to the
/// shared tokio-uring runtime and feeds it from a materializer stream via
/// `feed_async_writer`. If the job cannot be submitted, that run uses the
/// `tokio::fs`-based writer instead.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn uring_file_write_sink_or_fallback(path: PathBuf) -> TokioByteSink {
    let path = Arc::new(path);
    Sink::from_runner(move |input, materializer| {
        let path = Arc::clone(&path);
        // Capacity 1: at most one command queued ahead of the writer task.
        let (command_sender, command_receiver) = mpsc::channel(1);
        let (cancel_sender, cancel_receiver) = watch::channel(false);
        let (done_sender, done_receiver) = std_mpsc::sync_channel(1);
        let uring_path = Arc::clone(&path);

        let result = spawn_uring_job(Box::new(move || {
            Box::pin(async move {
                let task_result = AssertUnwindSafe(run_uring_write_task(
                    uring_path,
                    command_receiver,
                    cancel_receiver,
                ))
                .catch_unwind()
                .await
                .unwrap_or(Err(StreamError::AbruptTermination));
                let _ = done_sender.send(task_result);
            })
        }));

        if result.is_err() {
            // The rejected job was dropped; run this materialization on the tokio writer.
            let fallback_path = Arc::clone(&path);
            return run_async_write_sink::<tokio::fs::File, _, _>(
                input,
                materializer,
                Arc::new(move || tokio_file_write_open(Arc::clone(&fallback_path))),
            );
        }

        Ok(materializer.spawn_stream(move |cancelled| {
            // Cancels the uring task if `feed_async_writer` unwinds before `disarm`.
            let mut guard = TokioCancelGuard::new(cancel_sender.clone());
            let result = feed_async_writer(
                input,
                command_sender,
                done_receiver,
                cancelled,
                cancel_sender,
            );
            guard.disarm();
            result
        }))
    })
}
1231
/// io_uring counterpart of `run_async_write_task`. It creates the file at `path` with
/// `tokio_uring::fs::File::create` and writes each `Chunk` at a running offset.
///
/// `Finish(status)` closes the file and reports the upstream status combined (`and`)
/// with the close result. A command channel that closes without `Finish` closes the
/// file (result ignored) and reports `StreamError::Cancelled`, as does a cancel
/// notification. Always returns `Ok`; failures are carried in the [`IoResult`].
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
async fn run_uring_write_task(
    path: Arc<PathBuf>,
    mut commands: mpsc::Receiver<WriteCommand>,
    mut cancel: watch::Receiver<bool>,
) -> StreamResult<IoResult> {
    let mut bytes = 0_u64;
    let file = tokio::select! {
        file = tokio_uring::fs::File::create(path.as_ref()) => match file {
            Ok(file) => file,
            Err(error) => return Ok(IoResult::failed(bytes, io_error(error))),
        },
        changed = cancel.changed() => {
            let _ = changed;
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        }
    };

    // File position for the next positional write.
    let mut offset = 0_u64;
    loop {
        let command = tokio::select! {
            command = commands.recv() => command,
            changed = cancel.changed() => {
                let _ = changed;
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        };

        match command {
            Some(WriteCommand::Chunk(chunk)) => {
                if let Err(error) =
                    write_uring_chunk(&file, chunk, &mut offset, &mut cancel, &mut bytes).await
                {
                    return Ok(IoResult::failed(bytes, error));
                }
            }
            Some(WriteCommand::Finish(upstream_status)) => {
                let close_status = tokio::select! {
                    result = file.close() => result.map_err(io_error),
                    changed = cancel.changed() => {
                        let _ = changed;
                        Err(StreamError::Cancelled)
                    }
                };
                return Ok(IoResult {
                    bytes,
                    status: upstream_status.and(close_status),
                });
            }
            // The feeder went away without finishing: close the file anyway.
            None => {
                let _ = file.close().await;
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        }
    }
}
1288
/// Writes all of `chunk` to `file` starting at `*offset`, retrying short writes.
///
/// After every successful partial write, `offset` and `bytes` are advanced by the
/// number of bytes written. On an error or cancellation they therefore still
/// reflect what reached the file.
///
/// - An empty chunk is a no-op.
/// - A write that reports zero bytes fails with `write_zero_error()`.
/// - Any completion of `cancel.changed()` abandons the in-flight write and
///   returns `StreamError::Cancelled`.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
async fn write_uring_chunk(
    file: &tokio_uring::fs::File,
    chunk: Vec<u8>,
    offset: &mut u64,
    cancel: &mut watch::Receiver<bool>,
    bytes: &mut u64,
) -> StreamResult<()> {
    // Brings `slice` into scope for `Vec<u8>`.
    use tokio_uring::buf::BoundedBuf;

    let len = chunk.len();
    if len == 0 {
        return Ok(());
    }

    let mut written = 0usize;
    let mut buffer = chunk;
    while written < len {
        // tokio-uring ops take ownership of the buffer. `slice` restricts this op
        // to the still-unwritten tail.
        let slice = buffer.slice(written..len);
        let (result, returned) = tokio::select! {
            result = file.write_at(slice, *offset).submit() => result,
            changed = cancel.changed() => {
                let _ = changed;
                return Err(StreamError::Cancelled);
            }
        };
        // Recover the owned Vec from the slice for the next iteration.
        buffer = returned.into_inner();

        match result {
            Ok(0) => return Err(write_zero_error()),
            Ok(n) => {
                written += n;
                *offset += n as u64;
                *bytes += n as u64;
            }
            Err(error) => return Err(io_error(error)),
        }
    }

    Ok(())
}
1330
/// Address pair describing an established TCP connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TcpConnection {
    /// Address of this side of the connection.
    pub local_addr: SocketAddr,
    /// Address of the peer.
    pub remote_addr: SocketAddr,
}

impl TcpConnection {
    /// Returns the address of this side of the connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }

    /// Returns the address of the peer.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}
1352
/// Describes a TCP listener created by [`TokioTcp::bind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TcpBinding {
    /// Address the listener is bound to, as reported by the OS after binding.
    pub local_addr: SocketAddr,
}

impl TcpBinding {
    /// Returns the address the listener is bound to.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
1367
1368pub struct TcpIncomingConnection {
1375 connection: TcpConnection,
1376 source: TokioByteSource,
1377 sink: TokioByteSink,
1378}
1379
1380impl TcpIncomingConnection {
1381 #[must_use]
1382 pub fn local_addr(&self) -> SocketAddr {
1383 self.connection.local_addr
1384 }
1385
1386 #[must_use]
1387 pub fn remote_addr(&self) -> SocketAddr {
1388 self.connection.remote_addr
1389 }
1390
1391 #[must_use]
1392 pub fn connection(&self) -> TcpConnection {
1393 self.connection
1394 }
1395
1396 #[must_use]
1400 pub fn into_parts(self) -> (TokioByteSource, TokioByteSink) {
1401 (self.source, self.sink)
1402 }
1403
1404 #[must_use]
1407 pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
1408 Flow::from_sink_and_source_coupled(self.sink, self.source)
1409 .map_materialized_value(|_| NotUsed)
1410 }
1411}
1412
1413pub struct TokioTcp;
1414
1415impl TokioTcp {
1416 #[must_use]
1423 pub fn outgoing_connection<A>(
1424 addr: A,
1425 chunk_size: usize,
1426 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TcpConnection>>
1427 where
1428 A: ToSocketAddrs + Clone + Send + Sync + 'static,
1429 {
1430 assert!(chunk_size > 0, "chunk size must be greater than zero");
1431 Flow::future_flow(move || {
1432 let addr = addr.clone();
1433 async move {
1434 let stream = TcpStream::connect(addr).await.map_err(io_error)?;
1435 Ok(tcp_flow_from_stream(stream, chunk_size))
1436 }
1437 })
1438 }
1439
1440 #[must_use]
1442 pub fn outgoing_connection_default<A>(
1443 addr: A,
1444 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TcpConnection>>
1445 where
1446 A: ToSocketAddrs + Clone + Send + Sync + 'static,
1447 {
1448 Self::outgoing_connection(addr, DEFAULT_CHUNK_SIZE)
1449 }
1450
1451 #[must_use]
1458 pub fn bind<A>(
1459 addr: A,
1460 chunk_size: usize,
1461 ) -> Source<TcpIncomingConnection, StreamCompletion<TcpBinding>>
1462 where
1463 A: ToSocketAddrs + Clone + Send + Sync + 'static,
1464 {
1465 assert!(chunk_size > 0, "chunk size must be greater than zero");
1466 Source::from_materialized_factory(move |_materializer| {
1467 let (demand_sender, demand_receiver) = mpsc::channel(1);
1468 let (cancel_sender, cancel_receiver) = watch::channel(false);
1469 let (binding_sender, binding_receiver) = oneshot::channel();
1470 let terminal = Arc::new(Mutex::new(None));
1471 let terminal_for_task = Arc::clone(&terminal);
1472 let addr = addr.clone();
1473
1474 crate::stream::stream_tokio_runtime().spawn(async move {
1475 let result = AssertUnwindSafe(run_tcp_bind_task(
1476 addr,
1477 chunk_size,
1478 demand_receiver,
1479 cancel_receiver,
1480 binding_sender,
1481 Arc::clone(&terminal_for_task),
1482 ))
1483 .catch_unwind()
1484 .await;
1485 if result.is_err() {
1486 finish_terminal(
1487 &terminal_for_task,
1488 DemandTerminal::Error(StreamError::AbruptTermination),
1489 );
1490 }
1491 });
1492
1493 Ok((
1494 Box::new(DemandSourceStream {
1495 demands: demand_sender,
1496 cancel: cancel_sender,
1497 terminal,
1498 done: false,
1499 }) as BoxStream<TcpIncomingConnection>,
1500 StreamCompletion::from_receiver(binding_receiver, None),
1501 ))
1502 })
1503 }
1504
1505 #[must_use]
1507 pub fn bind_default<A>(addr: A) -> Source<TcpIncomingConnection, StreamCompletion<TcpBinding>>
1508 where
1509 A: ToSocketAddrs + Clone + Send + Sync + 'static,
1510 {
1511 Self::bind(addr, DEFAULT_CHUNK_SIZE)
1512 }
1513}
1514
1515fn tcp_flow_from_stream(
1516 stream: TcpStream,
1517 chunk_size: usize,
1518) -> Flow<Vec<u8>, Vec<u8>, TcpConnection> {
1519 let connection = TcpConnection {
1520 local_addr: stream
1521 .local_addr()
1522 .expect("connected TCP stream has local address"),
1523 remote_addr: stream
1524 .peer_addr()
1525 .expect("connected TCP stream has peer address"),
1526 };
1527 let (read_half, write_half) = stream.into_split();
1528 let source = single_use_async_read_source(read_half, chunk_size);
1529 let sink = single_use_async_write_sink(write_half);
1530 Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| connection)
1537}
1538
1539fn single_use_async_read_source<R>(reader: R, chunk_size: usize) -> TokioByteSource
1540where
1541 R: AsyncRead + Unpin + Send + 'static,
1542{
1543 let reader = Arc::new(Mutex::new(Some(reader)));
1544 Source::from_materialized_factory(move |_materializer| {
1545 let reader = Arc::clone(&reader);
1546 Ok(async_read_source(
1547 move || async move {
1548 reader
1549 .lock()
1550 .expect("single-use async reader poisoned")
1551 .take()
1552 .ok_or_else(|| std::io::Error::other("async reader already materialized"))
1553 },
1554 chunk_size,
1555 chunk_size,
1556 TCP_READ_AHEAD_CHUNKS,
1557 ))
1558 })
1559}
1560
1561fn single_use_async_write_sink<W>(writer: W) -> TokioByteSink
1562where
1563 W: AsyncWrite + Unpin + Send + 'static,
1564{
1565 let writer = Arc::new(Mutex::new(Some(writer)));
1566 async_write_sink(move || {
1567 let writer = Arc::clone(&writer);
1568 async move {
1569 writer
1570 .lock()
1571 .expect("single-use async writer poisoned")
1572 .take()
1573 .ok_or_else(|| std::io::Error::other("async writer already materialized"))
1574 }
1575 })
1576}
1577
/// Listener task behind [`TokioTcp::bind`].
///
/// Steps:
/// 1. Bind a `TcpListener` on `addr` and report the outcome on `binding_sender`.
/// 2. For each demand received on `demands`, accept one connection and reply to
///    that demand's std channel with a [`TcpIncomingConnection`].
///
/// On every terminating path, `terminal` is set via `finish_terminal` *before*
/// the error is reported through `binding_sender` or the pending reply.
///
/// Terminating paths:
/// - Bind or `local_addr` failure: the error goes to both the terminal and the
///   binding result.
/// - No further demand (per `next_demand`), a change on `cancel`, or a dropped
///   reply receiver: terminal `StreamError::Cancelled`.
/// - A non-transient accept error: terminal error, also sent as
///   `DemandResponse::Error` to the waiting demand. Transient errors (per
///   `is_transient_accept_error`) are retried.
async fn run_tcp_bind_task<A>(
    addr: A,
    chunk_size: usize,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TcpIncomingConnection>>>,
    mut cancel: watch::Receiver<bool>,
    binding_sender: oneshot::Sender<StreamResult<TcpBinding>>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
) where
    A: ToSocketAddrs + Send + 'static,
{
    let listener = match TcpListener::bind(addr).await {
        Ok(listener) => listener,
        Err(error) => {
            let error = io_error(error);
            finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
            let _ = binding_sender.send(Err(error));
            return;
        }
    };
    let local_addr = match listener.local_addr() {
        Ok(local_addr) => local_addr,
        Err(error) => {
            let error = io_error(error);
            finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
            let _ = binding_sender.send(Err(error));
            return;
        }
    };
    let _ = binding_sender.send(Ok(TcpBinding { local_addr }));

    loop {
        // Accept only when downstream has asked for a connection.
        let Some(reply) = next_demand(&mut demands, &mut cancel).await else {
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return;
        };

        // Retry transient accept failures; cancellation interrupts the wait.
        let (stream, remote_addr) = loop {
            let accepted = tokio::select! {
                accepted = listener.accept() => accepted,
                changed = cancel.changed() => {
                    let _ = changed;
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return;
                }
            };

            match accepted {
                Ok(accepted) => break accepted,
                Err(error) if is_transient_accept_error(&error) => continue,
                Err(error) => {
                    let error = io_error(error);
                    finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                    let _ = reply.send(DemandResponse::Error(error));
                    return;
                }
            }
        };

        let incoming = tcp_incoming_connection(stream, remote_addr, local_addr, chunk_size);
        // A failed send means the consumer dropped its receiver; treat it as cancellation.
        if reply.send(DemandResponse::Item(incoming)).is_err() {
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return;
        }
    }
}
1646
1647fn is_transient_accept_error(error: &std::io::Error) -> bool {
1648 matches!(
1649 error.kind(),
1650 std::io::ErrorKind::Interrupted
1651 | std::io::ErrorKind::ConnectionAborted
1652 | std::io::ErrorKind::ConnectionReset
1653 ) || error.raw_os_error().is_some_and(is_transient_accept_errno)
1654}
1655
/// Linux `accept(2)` errno values that should be retried rather than treated as
/// fatal for the listener.
///
/// Besides `EINTR`, `ECONNABORTED` and `ECONNRESET`, Linux `accept(2)` reports
/// network errors already pending on the *new* socket as the result of `accept`
/// itself. The man page tells callers to treat the TCP/IP ones like `EAGAIN` and
/// retry: `ENETDOWN`, `EPROTO`, `ENOPROTOOPT`, `EHOSTDOWN`, `ENONET`,
/// `EHOSTUNREACH`, `EOPNOTSUPP`, `ENETUNREACH`. Without them, a single
/// unreachable or misbehaving peer would tear down the whole listener.
///
/// The numbers below are the `asm-generic` errno values used by x86, ARM,
/// AArch64, RISC-V, PowerPC and s390x Linux. NOTE(review): MIPS and SPARC Linux
/// number errnos differently; confirm before supporting those targets.
#[cfg(target_os = "linux")]
fn is_transient_accept_errno(code: i32) -> bool {
    const EINTR: i32 = 4;
    const ENONET: i32 = 64;
    const EPROTO: i32 = 71;
    const ENOPROTOOPT: i32 = 92;
    const EOPNOTSUPP: i32 = 95;
    const ENETDOWN: i32 = 100;
    const ENETUNREACH: i32 = 101;
    const ECONNABORTED: i32 = 103;
    const ECONNRESET: i32 = 104;
    const EHOSTDOWN: i32 = 112;
    const EHOSTUNREACH: i32 = 113;

    matches!(
        code,
        EINTR
            | ECONNABORTED
            | ECONNRESET
            | ENETDOWN
            | EPROTO
            | ENOPROTOOPT
            | EHOSTDOWN
            | ENONET
            | EHOSTUNREACH
            | EOPNOTSUPP
            | ENETUNREACH
    )
}

/// On non-Linux platforms no raw errno is treated as transient; only the
/// `ErrorKind` checks in `is_transient_accept_error` apply.
#[cfg(not(target_os = "linux"))]
fn is_transient_accept_errno(_code: i32) -> bool {
    false
}
1665
1666fn tcp_incoming_connection(
1667 stream: TcpStream,
1668 remote_addr: SocketAddr,
1669 local_addr: SocketAddr,
1670 chunk_size: usize,
1671) -> TcpIncomingConnection {
1672 let connection = TcpConnection {
1673 local_addr,
1674 remote_addr,
1675 };
1676 let (read_half, write_half) = stream.into_split();
1677 let source = single_use_async_read_source(read_half, chunk_size);
1678 let sink = single_use_async_write_sink(write_half);
1679 TcpIncomingConnection {
1680 connection,
1681 source,
1682 sink,
1683 }
1684}
1685
// Tests for the Tokio/io_uring file stages and the Tokio TCP stages.
// File tests write into the system temp dir and remove their files on success.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testkit::TestSink;
    use crate::{Framing, Keep, Sink, Source};
    use std::pin::Pin;
    use std::sync::atomic::{AtomicBool as StdAtomicBool, Ordering as StdOrdering};
    use std::task::{Context, Poll};
    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

    // Builds a temp-file path unique per test name, process id and timestamp.
    fn unique_temp_path(name: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!(
            "datum-wp12b-{name}-{}-{nanos}.bin",
            std::process::id()
        ))
    }

    // Polls `condition` every 5 ms until it holds or `timeout` elapses.
    // Returns the final evaluation of the condition.
    fn wait_until(timeout: Duration, condition: impl Fn() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if condition() {
                return true;
            }
            thread::sleep(Duration::from_millis(5));
        }
        condition()
    }

    // Writer that never makes progress; it records when it is polled and when
    // it is dropped.
    struct PendingWriter {
        polled: Arc<StdAtomicBool>,
        dropped: Arc<StdAtomicBool>,
    }

    impl AsyncWrite for PendingWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.polled.store(true, StdOrdering::SeqCst);
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            self.polled.store(true, StdOrdering::SeqCst);
            Poll::Pending
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            self.polled.store(true, StdOrdering::SeqCst);
            Poll::Pending
        }
    }

    impl Drop for PendingWriter {
        fn drop(&mut self) {
            self.dropped.store(true, StdOrdering::SeqCst);
        }
    }

    // Write then read back through the Tokio file stages; both IoResults report 4 bytes.
    #[test]
    fn tokio_file_io_round_trips_bytes_and_reports_counts() {
        let path = unique_temp_path("roundtrip");
        let write_completion = Source::from_iter([b"ab".to_vec(), b"cd".to_vec()])
            .run_with(TokioFileIO::to_path(path.clone()))
            .expect("tokio file sink materializes");
        let write_result = write_completion.wait().expect("tokio file write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));

        let (read_completion, collected) = TokioFileIO::from_path(path.clone(), 2)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");
        assert_eq!(
            collected.wait().expect("collect completes"),
            vec![b"ab".to_vec(), b"cd".to_vec()]
        );
        let read_result = read_completion.wait().expect("read completion available");
        assert_eq!(read_result.bytes(), 4);
        assert_eq!(read_result.status(), Ok(()));

        std::fs::remove_file(path).expect("remove roundtrip file");
    }

    // A missing file fails both the stream and the IoResult, with zero bytes read.
    #[test]
    fn tokio_file_source_surfaces_open_failure() {
        let missing = unique_temp_path("missing");
        let (read_completion, collected) = TokioFileIO::from_path(missing, 4)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes despite open failure");
        let stream_error = collected.wait().expect_err("collect fails");
        assert!(matches!(stream_error, StreamError::Failed(_)));
        let read_result = read_completion.wait().expect("io result available");
        assert_eq!(read_result.bytes(), 0);
        assert!(matches!(read_result.status(), Err(StreamError::Failed(_))));
    }

    // Chunks that straddle delimiters are re-framed correctly downstream.
    #[test]
    fn tokio_file_source_composes_with_framing_and_sink() {
        let path = unique_temp_path("framing");
        std::fs::write(&path, b"alpha\nbeta\ngamma\n").expect("write framed seed file");

        let frames = TokioFileIO::from_path(path.clone(), 5)
            .via(Framing::delimiter(b"\n".to_vec(), 64, true))
            .run_with(Sink::collect())
            .expect("framed file stream materializes")
            .wait()
            .expect("framed file stream completes");

        assert_eq!(
            frames,
            vec![b"alpha".to_vec(), b"beta".to_vec(), b"gamma".to_vec()]
        );
        std::fs::remove_file(path).expect("remove framed file");
    }

    // Reads larger than FILE_INTERNAL_READ_SIZE are still emitted in
    // `chunk_size` pieces, plus one short tail chunk.
    #[test]
    fn tokio_file_source_preserves_requested_chunk_boundaries() {
        let path = unique_temp_path("chunk-boundaries");
        let chunk_size = 8192;
        let tail_size = 13;
        let data_len = FILE_INTERNAL_READ_SIZE + tail_size;
        let data: Vec<u8> = (0..data_len).map(|index| (index % 251) as u8).collect();
        std::fs::write(&path, &data).expect("write chunk boundary seed file");

        let (read_completion, chunks) = TokioFileIO::from_path(path.clone(), chunk_size)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");
        let chunks = chunks.wait().expect("chunk boundary stream completes");

        assert!(chunks.len() > 1);
        for chunk in &chunks[..chunks.len() - 1] {
            assert_eq!(chunk.len(), chunk_size);
        }
        assert_eq!(chunks.last().expect("tail chunk exists").len(), tail_size);
        let reassembled: Vec<u8> = chunks.into_iter().flatten().collect();
        assert_eq!(reassembled, data);
        assert_eq!(
            read_completion
                .wait()
                .expect("read completion available")
                .bytes(),
            data_len as u64
        );

        std::fs::remove_file(path).expect("remove chunk boundary file");
    }

    // io_uring variant of the file round-trip test.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_io_round_trips_bytes_and_reports_counts() {
        let path = unique_temp_path("uring-roundtrip");
        let write_completion = Source::from_iter([b"ab".to_vec(), b"cd".to_vec()])
            .run_with(UringFileIO::to_path(path.clone()))
            .expect("uring file sink materializes");
        let write_result = write_completion.wait().expect("uring file write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));

        let (read_completion, collected) = UringFileIO::from_path(path.clone(), 2)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("uring file source materializes");
        assert_eq!(
            collected.wait().expect("collect completes"),
            vec![b"ab".to_vec(), b"cd".to_vec()]
        );
        let read_result = read_completion.wait().expect("read completion available");
        assert_eq!(read_result.bytes(), 4);
        assert_eq!(read_result.status(), Ok(()));

        std::fs::remove_file(path).expect("remove uring roundtrip file");
    }

    // io_uring variant of the missing-file source test.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_source_surfaces_open_failure() {
        let missing = unique_temp_path("uring-missing");
        let (read_completion, collected) = UringFileIO::from_path(missing, 4)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("uring file source materializes despite open failure");
        let stream_error = collected.wait().expect_err("collect fails");
        assert!(matches!(stream_error, StreamError::Failed(_)));
        let read_result = read_completion.wait().expect("io result available");
        assert_eq!(read_result.bytes(), 0);
        assert!(matches!(read_result.status(), Err(StreamError::Failed(_))));
    }

    // Creating a file over an existing directory fails. The uring sink reports
    // Failed with zero bytes.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_sink_surfaces_open_failure() {
        let path = unique_temp_path("uring-sink-dir");
        std::fs::create_dir(&path).expect("create sink failure directory");

        let completion = Source::single(b"blocked".to_vec())
            .run_with(UringFileIO::to_path(path.clone()))
            .expect("uring file sink materializes despite open failure");
        let result = completion.wait().expect("uring sink result available");
        assert_eq!(result.bytes(), 0);
        assert!(matches!(result.status(), Err(StreamError::Failed(_))));

        std::fs::remove_dir(path).expect("remove sink failure directory");
    }

    // io_uring variant of the chunk-boundary test.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_source_preserves_requested_chunk_boundaries() {
        let path = unique_temp_path("uring-chunk-boundaries");
        let chunk_size = 8192;
        let tail_size = 13;
        let data_len = FILE_INTERNAL_READ_SIZE + tail_size;
        let data: Vec<u8> = (0..data_len).map(|index| (index % 251) as u8).collect();
        std::fs::write(&path, &data).expect("write uring chunk boundary seed file");

        let (read_completion, chunks) = UringFileIO::from_path(path.clone(), chunk_size)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("uring file source materializes");
        let chunks = chunks.wait().expect("chunk boundary stream completes");

        assert!(chunks.len() > 1);
        for chunk in &chunks[..chunks.len() - 1] {
            assert_eq!(chunk.len(), chunk_size);
        }
        assert_eq!(chunks.last().expect("tail chunk exists").len(), tail_size);
        let reassembled: Vec<u8> = chunks.into_iter().flatten().collect();
        assert_eq!(reassembled, data);
        assert_eq!(
            read_completion
                .wait()
                .expect("read completion available")
                .bytes(),
            data_len as u64
        );

        std::fs::remove_file(path).expect("remove uring chunk boundary file");
    }

    // `take(1)` cancels the uring source early. The IoResult reports partial
    // progress and a Cancelled status.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_source_cancellation_reports_cancelled() {
        let path = unique_temp_path("uring-cancel");
        let data_len = FILE_INTERNAL_READ_SIZE * 4;
        std::fs::write(&path, vec![b'z'; data_len]).expect("write uring cancel seed file");

        let (completion, ignored) = UringFileIO::from_path(path.clone(), 8192)
            .take(1)
            .to_mat(Sink::ignore(), Keep::both)
            .run()
            .expect("uring cancellation source materializes");
        ignored.wait().expect("downstream ignore completes");
        let result = completion.wait().expect("uring read completion available");

        assert!(result.bytes() > 0);
        assert!(matches!(result.status(), Err(StreamError::Cancelled)));

        std::fs::remove_file(path).expect("remove uring cancel file");
    }

    // Dropping the completion of a sink stuck on a never-ready writer must
    // eventually drop that writer.
    #[test]
    fn tokio_sink_cancellation_unblocks_pending_writer_completion_wait() {
        let polled = Arc::new(StdAtomicBool::new(false));
        let dropped = Arc::new(StdAtomicBool::new(false));
        let completion = Source::single(b"blocked".to_vec())
            .run_with(async_write_sink({
                let polled = Arc::clone(&polled);
                let dropped = Arc::clone(&dropped);
                move || {
                    let polled = Arc::clone(&polled);
                    let dropped = Arc::clone(&dropped);
                    async move { Ok(PendingWriter { polled, dropped }) }
                }
            }))
            .expect("pending writer sink materializes");

        assert!(wait_until(Duration::from_secs(1), || {
            polled.load(StdOrdering::SeqCst)
        }));
        drop(completion);
        assert!(wait_until(Duration::from_secs(1), || {
            dropped.load(StdOrdering::SeqCst)
        }));
    }

    // Only per-connection races are retried. An opaque error (standing in for
    // fd exhaustion) is not.
    #[test]
    fn tokio_tcp_accept_error_classifier_retries_only_connection_races() {
        assert!(is_transient_accept_error(&std::io::Error::new(
            std::io::ErrorKind::Interrupted,
            "interrupted"
        )));
        assert!(is_transient_accept_error(&std::io::Error::new(
            std::io::ErrorKind::ConnectionAborted,
            "aborted before accept"
        )));
        assert!(is_transient_accept_error(&std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "reset before accept"
        )));
        assert!(!is_transient_accept_error(&std::io::Error::other(
            "fd pressure"
        )));
    }

    // A source that blocks in `park_timeout` between items must notice the
    // stream's cancellation flag soon after the completion handle is dropped.
    #[test]
    fn tokio_source_cancellation_observed_promptly_under_wake_on_send() {
        let cancelled_detected = Arc::new(AtomicBool::new(false));
        let detected = Arc::clone(&cancelled_detected);

        let source = Source::from_materialized_factory(move |_materializer| {
            let d = Arc::clone(&detected);
            let mut i = 0_u64;
            let stream: BoxStream<Vec<u8>> = Box::new(std::iter::from_fn(move || {
                let is_cancelled = || {
                    crate::stream::current_stream_cancelled()
                        .as_ref()
                        .is_some_and(|c| c.load(Ordering::SeqCst))
                };
                if is_cancelled() {
                    d.store(true, Ordering::SeqCst);
                    return Some(Err(StreamError::Cancelled));
                }
                // Emit a few full chunks first, then switch to slow single-byte items.
                if i < 4 {
                    i += 1;
                    return Some(Ok(vec![0_u8; 8192]));
                }
                std::thread::park_timeout(Duration::from_millis(50));
                if is_cancelled() {
                    d.store(true, Ordering::SeqCst);
                    return Some(Err(StreamError::Cancelled));
                }
                Some(Ok(vec![0_u8]))
            }));
            Ok((stream, NotUsed))
        });

        let completion = source
            .run_with(Sink::ignore())
            .expect("cancellation source materializes");

        let cancel_thread = thread::spawn(move || {
            thread::sleep(Duration::from_millis(20));
            drop(completion);
        });

        let deadline = Instant::now() + Duration::from_secs(3);
        loop {
            if cancelled_detected.load(Ordering::SeqCst) {
                break;
            }
            assert!(
                Instant::now() < deadline,
                "cancellation not observed within 3 s; the park‑wake path may be broken"
            );
            thread::park_timeout(Duration::from_millis(10));
        }
        cancel_thread.join().expect("cancellation thread joins");
    }

    // End-to-end: a client sends "ping" via outgoing_connection; the server
    // side of `bind` echoes it back.
    #[test]
    fn tokio_tcp_bind_and_outgoing_connection_echo_round_trip() {
        let (binding_completion, incoming_completion) = TokioTcp::bind("127.0.0.1:0", 1024)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("tcp bind source materializes");
        let binding = binding_completion.wait().expect("tcp binding succeeds");

        let client_completion = Source::single(b"ping".to_vec())
            .via(TokioTcp::outgoing_connection(binding.local_addr(), 1024))
            .run_with(Sink::head())
            .expect("client stream materializes");

        let incoming = incoming_completion
            .wait()
            .expect("incoming connection accepted");
        let (incoming_source, incoming_sink) = incoming.into_parts();
        let server_read = incoming_source
            .run_with(Sink::head())
            .expect("server read materializes")
            .wait()
            .expect("server reads request");
        assert_eq!(server_read, b"ping".to_vec());

        let server_write = Source::single(server_read)
            .run_with(incoming_sink)
            .expect("server write materializes");
        let write_result = server_write.wait().expect("server write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));

        assert_eq!(
            client_completion.wait().expect("client receives echo"),
            b"ping".to_vec()
        );
    }

    // Each `request(1)` on the bind source yields exactly one newly accepted
    // connection with the correct addresses.
    #[test]
    fn tokio_tcp_bind_accepts_multiple_connections_on_successive_demands() {
        let (binding_completion, mut incoming_probe) = TokioTcp::bind("127.0.0.1:0", 1024)
            .to_mat(TestSink::probe(), Keep::both)
            .run()
            .expect("tcp bind source materializes");
        let binding = binding_completion.wait().expect("tcp binding succeeds");
        incoming_probe.set_timeout(Duration::from_secs(30));

        incoming_probe.request(1);
        let first_client =
            std::net::TcpStream::connect(binding.local_addr()).expect("first client connects");
        let first = incoming_probe.expect_next();
        assert_eq!(first.local_addr(), binding.local_addr());
        assert_eq!(
            first.remote_addr(),
            first_client.local_addr().expect("first client local addr")
        );

        incoming_probe.request(1);
        let second_client =
            std::net::TcpStream::connect(binding.local_addr()).expect("second client connects");
        let second = incoming_probe.expect_next();
        assert_eq!(second.local_addr(), binding.local_addr());
        assert_eq!(
            second.remote_addr(),
            second_client
                .local_addr()
                .expect("second client local addr")
        );
        assert_ne!(first.remote_addr(), second.remote_addr());

        incoming_probe.cancel();
    }
}