1use crate::stream::{BoxStream, Flow, NotUsed, Sink, Source, StreamCompletion};
2use crate::{StreamError, StreamResult};
3use futures::{FutureExt, channel::oneshot};
4use std::future::Future;
5use std::net::SocketAddr;
6use std::panic::AssertUnwindSafe;
7use std::path::PathBuf;
8use std::sync::{
9 Arc, Mutex,
10 atomic::{AtomicBool, Ordering},
11 mpsc as std_mpsc,
12};
13use std::thread::{self, Thread};
14use std::time::Duration;
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
17use tokio::sync::{mpsc, watch};
18
/// Chunk size, in bytes, used by the `from_path_default` constructors.
const DEFAULT_CHUNK_SIZE: usize = 8192;
/// Capacity of the item channel between a file reader task and its consumer.
const FILE_READ_AHEAD_CHUNKS: usize = 8;
/// Size of the internal read buffer for file sources; callers raise it to the
/// chunk size when the chunk size is larger.
const FILE_INTERNAL_READ_SIZE: usize = 256 * 1024;
// NOTE(review): presumably the read-ahead bound for TCP sources; it is not
// referenced in the part of the file visible here — confirm against its users.
const TCP_READ_AHEAD_CHUNKS: usize = 1;
/// Timeout used for blocking waits (`recv_timeout` / `park_timeout`) so that
/// cancellation flags are re-checked regularly.
const PARK_INTERVAL: Duration = Duration::from_millis(1);
/// Number of `yield_now` spins `read_wait` performs before it starts parking.
const READ_READY_SPINS: usize = 256;
/// Number of `yield_now` spins `backpressure_wait` performs before it starts parking.
const BACKPRESSURE_READY_SPINS: usize = 64;
/// Park duration used by `backpressure_wait` once its spin budget is spent.
const BACKPRESSURE_PARK: Duration = Duration::from_micros(10);
27
/// Shared handle that lets a producer task unpark the thread consuming a stream.
///
/// Clones share a single slot, so a thread recorded through one clone is the
/// thread woken by [`ConsumerWaker::unpark`] on every other clone.
#[derive(Clone)]
struct ConsumerWaker {
    thread: Arc<Mutex<Option<Thread>>>,
}

impl ConsumerWaker {
    /// Creates a waker with no consumer thread recorded yet.
    fn new() -> Self {
        Self {
            thread: Arc::new(Mutex::new(None)),
        }
    }

    /// Records the calling thread as the consumer to wake.
    ///
    /// The slot is refreshed whenever the caller differs from the recorded
    /// thread. Keeping only the first thread ever seen would leave the
    /// producer unparking a stale handle if the consumer is later polled from
    /// another thread.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn capture_current(&self) {
        let current = thread::current();
        let mut slot = self.thread.lock().expect("consumer waker poisoned");
        let already_recorded = slot
            .as_ref()
            .is_some_and(|known| known.id() == current.id());
        if !already_recorded {
            *slot = Some(current);
        }
    }

    /// Unparks the recorded consumer thread; does nothing when none is recorded.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn unpark(&self) {
        let slot = self.thread.lock().expect("consumer waker poisoned");
        if let Some(t) = slot.as_ref() {
            t.unpark();
        }
    }
}
61
62fn io_error(error: std::io::Error) -> StreamError {
63 StreamError::Failed(error.to_string())
64}
65
66fn write_zero_error() -> StreamError {
67 StreamError::Failed("async writer returned zero bytes".to_owned())
68}
69
70#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct IoResult {
78 pub bytes: u64,
79 pub status: StreamResult<()>,
80}
81
82impl IoResult {
83 #[must_use]
84 pub fn succeeded(bytes: u64) -> Self {
85 Self {
86 bytes,
87 status: Ok(()),
88 }
89 }
90
91 #[must_use]
92 pub fn failed(bytes: u64, error: StreamError) -> Self {
93 Self {
94 bytes,
95 status: Err(error),
96 }
97 }
98
99 #[must_use]
100 pub fn bytes(&self) -> u64 {
101 self.bytes
102 }
103
104 pub fn status(&self) -> StreamResult<()> {
105 self.status.clone()
106 }
107
108 #[must_use]
109 pub fn is_success(&self) -> bool {
110 self.status.is_ok()
111 }
112}
113
/// Source of `Vec<u8>` chunks paired with a `StreamCompletion` that yields the [`IoResult`].
pub type TokioByteSource = Source<Vec<u8>, StreamCompletion<IoResult>>;
/// Sink of `Vec<u8>` chunks paired with a `StreamCompletion` that yields the [`IoResult`].
pub type TokioByteSink = Sink<Vec<u8>, StreamCompletion<IoResult>>;
116
/// Final outcome a producer task records in the shared terminal slot.
///
/// Consumers consult it when their channel disconnects before a terminal
/// message reaches them.
#[derive(Clone)]
enum DemandTerminal {
    /// The producer reached the end of its input.
    Complete,
    /// The producer stopped with this error.
    Error(StreamError),
}
122
/// Message sent from a producer task to the consuming stream.
enum DemandResponse<T> {
    /// One stream element.
    Item(T),
    /// The producer finished; no further elements follow.
    Complete,
    /// The producer failed with this error.
    Error(StreamError),
}
128
129struct DemandSourceStream<T> {
130 demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<T>>>,
131 cancel: watch::Sender<bool>,
132 terminal: Arc<Mutex<Option<DemandTerminal>>>,
133 done: bool,
134}
135
136impl<T> DemandSourceStream<T> {
137 fn terminal_response(&self) -> Option<Option<StreamResult<T>>> {
138 self.terminal
139 .lock()
140 .expect("tokio source terminal poisoned")
141 .clone()
142 .map(|terminal| match terminal {
143 DemandTerminal::Complete => None,
144 DemandTerminal::Error(error) => Some(Err(error)),
145 })
146 }
147
148 fn mark_done(&mut self) {
149 self.done = true;
150 let _ = self.cancel.send(true);
151 }
152}
153
154impl<T: Send + 'static> Iterator for DemandSourceStream<T> {
155 type Item = StreamResult<T>;
156
157 fn next(&mut self) -> Option<Self::Item> {
158 if self.done {
159 return None;
160 }
161
162 let stream_cancelled = crate::stream::current_stream_cancelled();
163 let (reply_sender, reply_receiver) = std_mpsc::channel();
164 if !send_bounded_demand(&self.demands, reply_sender, &stream_cancelled) {
165 self.mark_done();
166 return self
167 .terminal_response()
168 .unwrap_or(Some(Err(StreamError::AbruptTermination)));
169 }
170
171 loop {
172 if stream_cancelled
173 .as_ref()
174 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
175 {
176 self.mark_done();
177 return Some(Err(StreamError::Cancelled));
178 }
179
180 match reply_receiver.recv_timeout(PARK_INTERVAL) {
181 Ok(DemandResponse::Item(item)) => return Some(Ok(item)),
182 Ok(DemandResponse::Complete) => {
183 self.mark_done();
184 return None;
185 }
186 Ok(DemandResponse::Error(error)) => {
187 self.mark_done();
188 return Some(Err(error));
189 }
190 Err(std_mpsc::RecvTimeoutError::Timeout) => {}
191 Err(std_mpsc::RecvTimeoutError::Disconnected) => {
192 self.mark_done();
193 return self
194 .terminal_response()
195 .unwrap_or(Some(Err(StreamError::AbruptTermination)));
196 }
197 }
198 }
199 }
200}
201
202impl<T> Drop for DemandSourceStream<T> {
203 fn drop(&mut self) {
204 let _ = self.cancel.send(true);
205 }
206}
207
208struct BoundedByteSourceStream {
209 receiver: mpsc::Receiver<DemandResponse<Vec<u8>>>,
210 cancel: watch::Sender<bool>,
211 terminal: Arc<Mutex<Option<DemandTerminal>>>,
212 done: bool,
213 waker: ConsumerWaker,
214}
215
216impl BoundedByteSourceStream {
217 fn terminal_response(&self) -> Option<Option<StreamResult<Vec<u8>>>> {
218 self.terminal
219 .lock()
220 .expect("tokio source terminal poisoned")
221 .clone()
222 .map(|terminal| match terminal {
223 DemandTerminal::Complete => None,
224 DemandTerminal::Error(error) => Some(Err(error)),
225 })
226 }
227
228 fn mark_done(&mut self) {
229 self.done = true;
230 let _ = self.cancel.send(true);
231 }
232}
233
234impl Iterator for BoundedByteSourceStream {
235 type Item = StreamResult<Vec<u8>>;
236
237 fn next(&mut self) -> Option<Self::Item> {
238 if self.done {
239 return None;
240 }
241
242 self.waker.capture_current();
246
247 let stream_cancelled = crate::stream::current_stream_cancelled();
248 let mut spins = 0usize;
249 loop {
250 if stream_cancelled
251 .as_ref()
252 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
253 {
254 self.mark_done();
255 return Some(Err(StreamError::Cancelled));
256 }
257
258 match self.receiver.try_recv() {
259 Ok(DemandResponse::Item(item)) => return Some(Ok(item)),
260 Ok(DemandResponse::Complete) => {
261 self.mark_done();
262 return None;
263 }
264 Ok(DemandResponse::Error(error)) => {
265 self.mark_done();
266 return Some(Err(error));
267 }
268 Err(mpsc::error::TryRecvError::Empty) => read_wait(&mut spins),
269 Err(mpsc::error::TryRecvError::Disconnected) => {
270 self.mark_done();
271 return self
272 .terminal_response()
273 .unwrap_or(Some(Err(StreamError::AbruptTermination)));
274 }
275 }
276 }
277 }
278}
279
280impl Drop for BoundedByteSourceStream {
281 fn drop(&mut self) {
282 let _ = self.cancel.send(true);
283 }
284}
285
286fn send_bounded_demand<T>(
287 sender: &mpsc::Sender<T>,
288 mut message: T,
289 stream_cancelled: &Option<Arc<AtomicBool>>,
290) -> bool {
291 let mut spins = 0usize;
292 loop {
293 if stream_cancelled
294 .as_ref()
295 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
296 {
297 return false;
298 }
299
300 match sender.try_send(message) {
301 Ok(()) => return true,
302 Err(mpsc::error::TrySendError::Full(returned)) => {
303 message = returned;
304 backpressure_wait(&mut spins);
305 }
306 Err(mpsc::error::TrySendError::Closed(_)) => return false,
307 }
308 }
309}
310
311fn finish_terminal(terminal: &Arc<Mutex<Option<DemandTerminal>>>, value: DemandTerminal) {
312 let mut slot = terminal.lock().expect("tokio source terminal poisoned");
313 if slot.is_none() {
314 *slot = Some(value);
315 }
316}
317
318async fn next_demand<T>(
319 demands: &mut mpsc::Receiver<std_mpsc::Sender<DemandResponse<T>>>,
320 cancel: &mut watch::Receiver<bool>,
321) -> Option<std_mpsc::Sender<DemandResponse<T>>> {
322 if *cancel.borrow() {
323 return None;
324 }
325
326 tokio::select! {
327 demand = demands.recv() => demand,
328 changed = cancel.changed() => {
329 let _ = changed;
330 None
331 }
332 }
333}
334
335fn async_read_source<R, Fut>(
336 open: impl FnOnce() -> Fut + Send + 'static,
337 chunk_size: usize,
338 internal_read_size: usize,
339 read_ahead_chunks: usize,
340) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>)
341where
342 R: AsyncRead + Unpin + Send + 'static,
343 Fut: Future<Output = std::io::Result<R>> + Send + 'static,
344{
345 assert!(chunk_size > 0, "chunk size must be greater than zero");
346 assert!(
347 read_ahead_chunks > 0,
348 "read-ahead bound must be greater than zero"
349 );
350 let internal_read_size = internal_read_size.max(chunk_size);
351 let (item_sender, item_receiver) = mpsc::channel(read_ahead_chunks);
352 let (cancel_sender, cancel_receiver) = watch::channel(false);
353 let (mat_sender, mat_receiver) = oneshot::channel();
354 let terminal = Arc::new(Mutex::new(None));
355 let terminal_for_task = Arc::clone(&terminal);
356 let waker = ConsumerWaker::new();
357 let producer_waker = waker.clone();
358
359 crate::stream::stream_tokio_runtime().spawn(async move {
360 let result = AssertUnwindSafe(run_async_read_task(
361 open(),
362 chunk_size,
363 internal_read_size,
364 item_sender,
365 cancel_receiver,
366 Arc::clone(&terminal_for_task),
367 producer_waker,
368 ))
369 .catch_unwind()
370 .await
371 .unwrap_or_else(|_| {
372 finish_terminal(
373 &terminal_for_task,
374 DemandTerminal::Error(StreamError::AbruptTermination),
375 );
376 Err(StreamError::AbruptTermination)
377 });
378 let _ = mat_sender.send(result);
379 });
380
381 (
382 Box::new(BoundedByteSourceStream {
383 receiver: item_receiver,
384 cancel: cancel_sender,
385 terminal,
386 done: false,
387 waker,
388 }) as BoxStream<Vec<u8>>,
389 StreamCompletion::from_receiver(mat_receiver, None),
390 )
391}
392
/// Opens the reader and reads it to the end, forwarding `chunk_size`-sized
/// chunks on `items`.
///
/// Every await point is raced against `cancel`. The terminal slot is always
/// written before the matching terminal message is sent, so a consumer that
/// only sees a disconnected channel can still recover the outcome.
///
/// The returned `IoResult` counts bytes read from the reader, including bytes
/// that were read but never delivered to the consumer.
async fn run_async_read_task<R, Fut>(
    open: Fut,
    chunk_size: usize,
    internal_read_size: usize,
    items: mpsc::Sender<DemandResponse<Vec<u8>>>,
    mut cancel: watch::Receiver<bool>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
    waker: ConsumerWaker,
) -> StreamResult<IoResult>
where
    R: AsyncRead + Unpin + Send + 'static,
    Fut: Future<Output = std::io::Result<R>> + Send + 'static,
{
    let mut bytes = 0_u64;
    let mut reader = tokio::select! {
        reader = open => match reader {
            Ok(reader) => reader,
            Err(error) => {
                // Opening failed: record the error, try to deliver it, and finish.
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(&items, DemandResponse::Error(error.clone()), &mut cancel, &waker).await;
                return Ok(IoResult::failed(bytes, error));
            }
        },
        changed = cancel.changed() => {
            let _ = changed;
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        }
    };

    let mut buffer = vec![0_u8; internal_read_size];
    // Bytes left over from a read that did not fill a whole chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    loop {
        let read = tokio::select! {
            read = reader.read(&mut buffer) => read,
            changed = cancel.changed() => {
                let _ = changed;
                finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        };

        match read {
            Ok(0) => {
                // End of input: flush the partial chunk, if any, then complete.
                // A failed send (cancelled or receiver gone) is reported as Cancelled.
                if !pending_tail.is_empty()
                    && !send_read_item(
                        &items,
                        DemandResponse::Item(std::mem::take(&mut pending_tail)),
                        &mut cancel,
                        &waker,
                    )
                    .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
                finish_terminal(&terminal, DemandTerminal::Complete);
                let _ = send_read_item(&items, DemandResponse::Complete, &mut cancel, &waker).await;
                return Ok(IoResult::succeeded(bytes));
            }
            Ok(read) => {
                bytes += read as u64;
                if !send_read_chunks(
                    &items,
                    chunk_size,
                    &mut pending_tail,
                    &buffer[..read],
                    &mut cancel,
                    &waker,
                )
                .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
            }
            Err(error) => {
                // Read failure: record it, try to deliver it, and finish.
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(
                    &items,
                    DemandResponse::Error(error.clone()),
                    &mut cancel,
                    &waker,
                )
                .await;
                return Ok(IoResult::failed(bytes, error));
            }
        }
    }
}
485
486async fn send_read_chunks(
487 sender: &mpsc::Sender<DemandResponse<Vec<u8>>>,
488 chunk_size: usize,
489 pending_tail: &mut Vec<u8>,
490 read_buffer: &[u8],
491 cancel: &mut watch::Receiver<bool>,
492 waker: &ConsumerWaker,
493) -> bool {
494 let mut offset = 0;
495 if !pending_tail.is_empty() {
496 let needed = chunk_size - pending_tail.len();
497 let take = needed.min(read_buffer.len());
498 pending_tail.extend_from_slice(&read_buffer[..take]);
499 offset += take;
500 if pending_tail.len() == chunk_size
501 && !send_read_item(
502 sender,
503 DemandResponse::Item(std::mem::take(pending_tail)),
504 cancel,
505 waker,
506 )
507 .await
508 {
509 return false;
510 }
511 }
512
513 while offset + chunk_size <= read_buffer.len() {
514 let next = offset + chunk_size;
515 if !send_read_item(
516 sender,
517 DemandResponse::Item(read_buffer[offset..next].to_vec()),
518 cancel,
519 waker,
520 )
521 .await
522 {
523 return false;
524 }
525 offset = next;
526 }
527
528 if offset < read_buffer.len() {
529 pending_tail.extend_from_slice(&read_buffer[offset..]);
530 }
531 true
532}
533
534async fn send_read_item<T>(
535 sender: &mpsc::Sender<DemandResponse<T>>,
536 item: DemandResponse<T>,
537 cancel: &mut watch::Receiver<bool>,
538 waker: &ConsumerWaker,
539) -> bool
540where
541 T: Send + 'static,
542{
543 let result = tokio::select! {
544 result = sender.send(item) => result,
545 changed = cancel.changed() => {
546 let _ = changed;
547 return false;
548 }
549 };
550 if result.is_ok() {
551 waker.unpark();
552 }
553 result.is_ok()
554}
555
/// Instruction sent from the blocking feeder to the async writer task.
enum WriteCommand {
    /// Write these bytes.
    Chunk(Vec<u8>),
    /// Upstream ended with this status; flush and shut the writer down.
    Finish(StreamResult<()>),
}
560
561struct TokioCancelGuard {
562 cancel: watch::Sender<bool>,
563 armed: bool,
564}
565
566impl TokioCancelGuard {
567 fn new(cancel: watch::Sender<bool>) -> Self {
568 Self {
569 cancel,
570 armed: true,
571 }
572 }
573
574 fn disarm(&mut self) {
575 self.armed = false;
576 }
577}
578
579impl Drop for TokioCancelGuard {
580 fn drop(&mut self) {
581 if self.armed {
582 let _ = self.cancel.send(true);
583 }
584 }
585}
586
587fn async_write_sink<W, F, Fut>(open: F) -> TokioByteSink
588where
589 W: AsyncWrite + Unpin + Send + 'static,
590 F: Fn() -> Fut + Send + Sync + 'static,
591 Fut: Future<Output = std::io::Result<W>> + Send + 'static,
592{
593 let open = Arc::new(open);
594 Sink::from_runner(move |input, materializer| {
595 run_async_write_sink::<W, F, Fut>(input, materializer, Arc::clone(&open))
596 })
597}
598
599async fn run_async_write_task<W, Fut>(
600 open: Fut,
601 mut commands: mpsc::Receiver<WriteCommand>,
602 mut cancel: watch::Receiver<bool>,
603) -> StreamResult<IoResult>
604where
605 W: AsyncWrite + Unpin + Send + 'static,
606 Fut: Future<Output = std::io::Result<W>> + Send + 'static,
607{
608 let mut bytes = 0_u64;
609 let mut writer = tokio::select! {
610 writer = open => match writer {
611 Ok(writer) => writer,
612 Err(error) => return Ok(IoResult::failed(bytes, io_error(error))),
613 },
614 changed = cancel.changed() => {
615 let _ = changed;
616 return Ok(IoResult::failed(bytes, StreamError::Cancelled));
617 }
618 };
619
620 loop {
621 let command = tokio::select! {
622 command = commands.recv() => command,
623 changed = cancel.changed() => {
624 let _ = changed;
625 return Ok(IoResult::failed(bytes, StreamError::Cancelled));
626 }
627 };
628
629 match command {
630 Some(WriteCommand::Chunk(chunk)) => {
631 if let Err(error) = write_chunk(&mut writer, &chunk, &mut cancel, &mut bytes).await
632 {
633 return Ok(IoResult::failed(bytes, error));
634 }
635 }
636 Some(WriteCommand::Finish(upstream_status)) => {
637 let shutdown_status = shutdown_writer(&mut writer, &mut cancel).await;
638 return Ok(IoResult {
639 bytes,
640 status: upstream_status.and(shutdown_status),
641 });
642 }
643 None => {
644 let _ = shutdown_writer(&mut writer, &mut cancel).await;
645 return Ok(IoResult::failed(bytes, StreamError::Cancelled));
646 }
647 }
648 }
649}
650
651async fn write_chunk<W>(
652 writer: &mut W,
653 chunk: &[u8],
654 cancel: &mut watch::Receiver<bool>,
655 bytes: &mut u64,
656) -> StreamResult<()>
657where
658 W: AsyncWrite + Unpin,
659{
660 let mut offset = 0usize;
661 while offset < chunk.len() {
662 let written = tokio::select! {
663 written = writer.write(&chunk[offset..]) => written.map_err(io_error)?,
664 changed = cancel.changed() => {
665 let _ = changed;
666 return Err(StreamError::Cancelled);
667 }
668 };
669
670 if written == 0 {
671 return Err(write_zero_error());
672 }
673 offset += written;
674 *bytes += written as u64;
675 }
676 Ok(())
677}
678
679async fn shutdown_writer<W>(writer: &mut W, cancel: &mut watch::Receiver<bool>) -> StreamResult<()>
680where
681 W: AsyncWrite + Unpin,
682{
683 tokio::select! {
684 result = writer.flush() => result.map_err(io_error)?,
685 changed = cancel.changed() => {
686 let _ = changed;
687 return Err(StreamError::Cancelled);
688 }
689 }
690
691 tokio::select! {
692 result = writer.shutdown() => result.map_err(io_error),
693 changed = cancel.changed() => {
694 let _ = changed;
695 Err(StreamError::Cancelled)
696 }
697 }
698}
699
/// Drains `input` on the calling thread, forwards each chunk to the writer
/// task, and then blocks until that task reports its result.
///
/// `cancelled` is the stream's cancellation flag; when it is set the writer
/// task is told to stop through `cancel_sender` instead of receiving a
/// `Finish` command. Returns `StreamError::AbruptTermination` if the task's
/// result channel disconnects without a result.
fn feed_async_writer(
    mut input: BoxStream<Vec<u8>>,
    command_sender: mpsc::Sender<WriteCommand>,
    done_receiver: std_mpsc::Receiver<StreamResult<IoResult>>,
    cancelled: Arc<AtomicBool>,
    cancel_sender: watch::Sender<bool>,
) -> StreamResult<IoResult> {
    // Status that upstream ended with; forwarded to the writer in `Finish`.
    let mut terminal = Ok(());
    loop {
        if cancelled.load(Ordering::SeqCst) {
            terminal = Err(StreamError::Cancelled);
            break;
        }

        match input.next() {
            Some(Ok(chunk)) => {
                // `false` means cancelled or the writer task closed its receiver.
                if !send_write_command(&command_sender, WriteCommand::Chunk(chunk), &cancelled) {
                    break;
                }
            }
            Some(Err(error)) => {
                terminal = Err(error);
                break;
            }
            None => break,
        }
    }

    if cancelled.load(Ordering::SeqCst) {
        let _ = cancel_sender.send(true);
    } else {
        let _ = send_write_command(&command_sender, WriteCommand::Finish(terminal), &cancelled);
    }
    // Close the command channel so the task sees `None` if `Finish` was not delivered.
    drop(command_sender);

    loop {
        match done_receiver.recv_timeout(PARK_INTERVAL) {
            Ok(result) => return result,
            Err(std_mpsc::RecvTimeoutError::Timeout) => {
                // Cancellation may arrive while waiting; keep relaying it to the task.
                if cancelled.load(Ordering::SeqCst) {
                    let _ = cancel_sender.send(true);
                }
            }
            Err(std_mpsc::RecvTimeoutError::Disconnected) => {
                return Err(StreamError::AbruptTermination);
            }
        }
    }
}
752
753fn send_write_command(
754 sender: &mpsc::Sender<WriteCommand>,
755 mut command: WriteCommand,
756 cancelled: &AtomicBool,
757) -> bool {
758 let mut spins = 0usize;
759 loop {
760 if cancelled.load(Ordering::SeqCst) {
761 return false;
762 }
763
764 match sender.try_send(command) {
765 Ok(()) => return true,
766 Err(mpsc::error::TrySendError::Full(returned)) => {
767 command = returned;
768 backpressure_wait(&mut spins);
769 }
770 Err(mpsc::error::TrySendError::Closed(_)) => return false,
771 }
772 }
773}
774
775fn backpressure_wait(spins: &mut usize) {
776 if *spins < BACKPRESSURE_READY_SPINS {
777 *spins += 1;
778 thread::yield_now();
779 } else {
780 thread::park_timeout(BACKPRESSURE_PARK);
781 }
782}
783
784fn read_wait(spins: &mut usize) {
785 if *spins < READ_READY_SPINS {
786 *spins += 1;
787 thread::yield_now();
788 } else {
789 thread::park_timeout(PARK_INTERVAL);
790 }
791}
792
793fn tokio_file_read_source(
794 path: PathBuf,
795 chunk_size: usize,
796) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>) {
797 async_read_source(
798 move || tokio::fs::File::open(path),
799 chunk_size,
800 FILE_INTERNAL_READ_SIZE,
801 FILE_READ_AHEAD_CHUNKS,
802 )
803}
804
805async fn tokio_file_write_open(path: Arc<PathBuf>) -> std::io::Result<tokio::fs::File> {
806 tokio::fs::OpenOptions::new()
807 .create(true)
808 .truncate(true)
809 .write(true)
810 .open(path.as_ref())
811 .await
812}
813
814fn run_async_write_sink<W, F, Fut>(
815 input: BoxStream<Vec<u8>>,
816 materializer: &crate::stream::Materializer,
817 open: Arc<F>,
818) -> StreamResult<StreamCompletion<IoResult>>
819where
820 W: AsyncWrite + Unpin + Send + 'static,
821 F: Fn() -> Fut + Send + Sync + 'static,
822 Fut: Future<Output = std::io::Result<W>> + Send + 'static,
823{
824 let (command_sender, command_receiver) = mpsc::channel(1);
825 let (cancel_sender, cancel_receiver) = watch::channel(false);
826 let (done_sender, done_receiver) = std_mpsc::sync_channel(1);
827
828 crate::stream::stream_tokio_runtime().spawn(async move {
829 let result = AssertUnwindSafe(run_async_write_task(
830 open(),
831 command_receiver,
832 cancel_receiver,
833 ))
834 .catch_unwind()
835 .await
836 .unwrap_or(Err(StreamError::AbruptTermination));
837 let _ = done_sender.send(result);
838 });
839
840 Ok(materializer.spawn_stream(move |cancelled| {
841 let mut guard = TokioCancelGuard::new(cancel_sender.clone());
842 let result = feed_async_writer(
843 input,
844 command_sender,
845 done_receiver,
846 cancelled,
847 cancel_sender,
848 );
849 guard.disarm();
850 result
851 }))
852}
853
854fn tokio_file_write_sink(path: PathBuf) -> TokioByteSink {
855 let path = Arc::new(path);
856 async_write_sink(move || tokio_file_write_open(Arc::clone(&path)))
857}
858
859pub struct TokioFileIO;
860
861impl TokioFileIO {
862 #[must_use]
872 pub fn from_path(path: impl Into<PathBuf>, chunk_size: usize) -> TokioByteSource {
873 assert!(chunk_size > 0, "chunk size must be greater than zero");
874 let path = path.into();
875 Source::from_materialized_factory(move |_materializer| {
876 let path = path.clone();
877 Ok(tokio_file_read_source(path, chunk_size))
878 })
879 }
880
881 #[must_use]
882 pub fn from_path_default(path: impl Into<PathBuf>) -> TokioByteSource {
883 Self::from_path(path, DEFAULT_CHUNK_SIZE)
884 }
885
886 #[must_use]
893 pub fn to_path(path: impl Into<PathBuf>) -> TokioByteSink {
894 tokio_file_write_sink(path.into())
895 }
896}
897
/// Entry points for file-backed byte sources and sinks that use the io_uring
/// backend where it is available and fall back to Tokio file I/O otherwise.
#[cfg(feature = "io-uring-file")]
pub struct UringFileIO;

#[cfg(feature = "io-uring-file")]
impl UringFileIO {
    /// Creates a source that reads the file at `path` in `chunk_size`-byte chunks.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn from_path(path: impl Into<PathBuf>, chunk_size: usize) -> TokioByteSource {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        let path: PathBuf = path.into();
        Source::from_materialized_factory(move |_materializer| {
            Ok(uring_file_read_source_or_fallback(path.clone(), chunk_size))
        })
    }

    /// Creates a source that reads the file at `path` using the default chunk size.
    #[must_use]
    pub fn from_path_default(path: impl Into<PathBuf>) -> TokioByteSource {
        Self::from_path(path, DEFAULT_CHUNK_SIZE)
    }

    /// Creates a sink that writes incoming chunks to the file at `path`.
    #[must_use]
    pub fn to_path(path: impl Into<PathBuf>) -> TokioByteSink {
        let path: PathBuf = path.into();
        uring_file_write_sink_or_fallback(path)
    }
}
943
/// Non-Linux build of the uring read source: always delegates to the Tokio
/// file read source.
#[cfg(all(feature = "io-uring-file", not(target_os = "linux")))]
fn uring_file_read_source_or_fallback(
    path: PathBuf,
    chunk_size: usize,
) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>) {
    tokio_file_read_source(path, chunk_size)
}
951
/// Non-Linux build of the uring write sink: always delegates to the Tokio
/// file write sink.
#[cfg(all(feature = "io-uring-file", not(target_os = "linux")))]
fn uring_file_write_sink_or_fallback(path: PathBuf) -> TokioByteSink {
    tokio_file_write_sink(path)
}
956
/// Unit of work handed to the uring runtime thread.
///
/// The closure itself is `Send`; the future it creates is not required to be,
/// and is built and spawned on the runtime thread by `run_uring_jobs`.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
type UringJob =
    Box<dyn FnOnce() -> std::pin::Pin<Box<dyn Future<Output = ()> + 'static>> + Send + 'static>;
960
/// Cloneable handle used to submit jobs to the uring runtime thread.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
#[derive(Clone)]
struct UringRuntimeHandle {
    /// Unbounded queue read by the runtime thread's job loop.
    sender: mpsc::UnboundedSender<UringJob>,
}
966
/// Start-up report sent once by the uring runtime thread.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
enum UringRuntimeInit {
    /// The runtime was created and is about to process jobs.
    Ready,
    /// Creating the runtime failed with this error.
    Failed(std::io::Error),
}
972
/// Returns the process-wide uring runtime handle, starting the runtime on
/// first use.
///
/// The outcome of the first start attempt is cached, so a start-up failure is
/// returned as the same error message on every later call.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn uring_runtime_handle() -> Result<&'static UringRuntimeHandle, String> {
    static HANDLE: std::sync::OnceLock<Result<UringRuntimeHandle, String>> =
        std::sync::OnceLock::new();
    match HANDLE.get_or_init(start_uring_runtime) {
        Ok(handle) => Ok(handle),
        Err(message) => Err(message.clone()),
    }
}
982
/// Starts the dedicated `datum-uring-file` thread hosting the tokio-uring
/// runtime and waits for it to report whether start-up succeeded.
///
/// Returns an error message if the thread cannot be spawned, the runtime
/// cannot be created, or the thread exits before reporting.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn start_uring_runtime() -> Result<UringRuntimeHandle, String> {
    let (init_tx, init_rx) = std_mpsc::sync_channel(1);
    let (job_tx, job_rx) = mpsc::unbounded_channel::<UringJob>();

    let worker = move || {
        let runtime = match tokio_uring::Runtime::new(&tokio_uring::builder()) {
            Ok(runtime) => runtime,
            Err(error) => {
                let _ = init_tx.send(UringRuntimeInit::Failed(error));
                return;
            }
        };
        let _ = init_tx.send(UringRuntimeInit::Ready);
        runtime.block_on(run_uring_jobs(job_rx));
    };

    thread::Builder::new()
        .name("datum-uring-file".to_owned())
        .spawn(worker)
        .map_err(|error| error.to_string())?;

    match init_rx.recv() {
        Ok(UringRuntimeInit::Ready) => Ok(UringRuntimeHandle { sender: job_tx }),
        Ok(UringRuntimeInit::Failed(error)) => Err(error.to_string()),
        Err(_) => Err("tokio-uring runtime thread exited".to_owned()),
    }
}
1008
/// Job loop of the uring runtime thread: spawns every received job until the
/// job channel closes.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
async fn run_uring_jobs(mut receiver: mpsc::UnboundedReceiver<UringJob>) {
    loop {
        let Some(job) = receiver.recv().await else {
            break;
        };
        tokio_uring::spawn(job());
    }
}
1015
/// Queues `job` on the uring runtime thread.
///
/// Fails with an error message when the runtime could not be started or its
/// thread has already exited.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn spawn_uring_job(job: UringJob) -> Result<(), String> {
    let handle = uring_runtime_handle()?;
    match handle.sender.send(job) {
        Ok(()) => Ok(()),
        Err(_) => Err("tokio-uring runtime thread exited".to_owned()),
    }
}
1023
/// Linux build of the uring read source.
///
/// Queues a read task on the uring runtime thread and returns the blocking
/// chunk stream plus its completion handle. If the job cannot be queued, the
/// Tokio file read source is returned instead.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn uring_file_read_source_or_fallback(
    path: PathBuf,
    chunk_size: usize,
) -> (BoxStream<Vec<u8>>, StreamCompletion<IoResult>) {
    let internal_read_size = FILE_INTERNAL_READ_SIZE.max(chunk_size);
    let (item_sender, item_receiver) = mpsc::channel(FILE_READ_AHEAD_CHUNKS);
    let (cancel_sender, cancel_receiver) = watch::channel(false);
    let (mat_sender, mat_receiver) = oneshot::channel();
    let terminal = Arc::new(Mutex::new(None));
    let terminal_for_task = Arc::clone(&terminal);
    let waker = ConsumerWaker::new();
    let producer_waker = waker.clone();
    // The job takes a copy; `path` itself stays available for the fallback below.
    let uring_path = path.clone();

    let result = spawn_uring_job(Box::new(move || {
        Box::pin(async move {
            let task_result = AssertUnwindSafe(run_uring_read_task(
                uring_path,
                chunk_size,
                internal_read_size,
                item_sender,
                cancel_receiver,
                Arc::clone(&terminal_for_task),
                producer_waker,
            ))
            .catch_unwind()
            .await
            .unwrap_or_else(|_| {
                // A panic in the task is surfaced as an abrupt termination.
                finish_terminal(
                    &terminal_for_task,
                    DemandTerminal::Error(StreamError::AbruptTermination),
                );
                Err(StreamError::AbruptTermination)
            });
            let _ = mat_sender.send(task_result);
        })
    }));

    // The uring runtime is unavailable: use the Tokio implementation.
    if result.is_err() {
        return tokio_file_read_source(path, chunk_size);
    }

    (
        Box::new(BoundedByteSourceStream {
            receiver: item_receiver,
            cancel: cancel_sender,
            terminal,
            done: false,
            waker,
        }) as BoxStream<Vec<u8>>,
        StreamCompletion::from_receiver(mat_receiver, None),
    )
}
1078
/// Opens `path` with tokio-uring and reads it to the end, forwarding
/// `chunk_size`-sized chunks on `items`.
///
/// Every await point except the final terminal sends is raced against
/// `cancel`. The terminal slot is always written before the matching terminal
/// message is sent. The returned `IoResult` counts bytes read from the file,
/// including bytes that were read but never delivered.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
async fn run_uring_read_task(
    path: PathBuf,
    chunk_size: usize,
    internal_read_size: usize,
    items: mpsc::Sender<DemandResponse<Vec<u8>>>,
    mut cancel: watch::Receiver<bool>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
    waker: ConsumerWaker,
) -> StreamResult<IoResult> {
    let mut bytes = 0_u64;
    let file = tokio::select! {
        file = tokio_uring::fs::File::open(path) => match file {
            Ok(file) => file,
            Err(error) => {
                // Opening failed: record the error, try to deliver it, and finish.
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(&items, DemandResponse::Error(error.clone()), &mut cancel, &waker).await;
                return Ok(IoResult::failed(bytes, error));
            }
        },
        changed = cancel.changed() => {
            let _ = changed;
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        }
    };

    // File position of the next positional read.
    let mut offset = 0_u64;
    let mut buffer = Vec::with_capacity(internal_read_size);
    // Bytes left over from a read that did not fill a whole chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    loop {
        // `read_at` takes the buffer by value and hands it back with the result.
        let read_buffer = std::mem::take(&mut buffer);
        let (read, returned_buffer) = tokio::select! {
            result = file.read_at(read_buffer, offset) => result,
            changed = cancel.changed() => {
                let _ = changed;
                finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                return Ok(IoResult::failed(bytes, StreamError::Cancelled));
            }
        };
        buffer = returned_buffer;

        match read {
            Ok(0) => {
                // End of file: flush the partial chunk, if any, then complete.
                // A failed send (cancelled or receiver gone) is reported as Cancelled.
                if !pending_tail.is_empty()
                    && !send_read_item(
                        &items,
                        DemandResponse::Item(std::mem::take(&mut pending_tail)),
                        &mut cancel,
                        &waker,
                    )
                    .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
                finish_terminal(&terminal, DemandTerminal::Complete);
                let _ = send_read_item(&items, DemandResponse::Complete, &mut cancel, &waker).await;
                return Ok(IoResult::succeeded(bytes));
            }
            Ok(read) => {
                bytes += read as u64;
                offset += read as u64;
                if !send_read_chunks(
                    &items,
                    chunk_size,
                    &mut pending_tail,
                    &buffer[..read],
                    &mut cancel,
                    &waker,
                )
                .await
                {
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return Ok(IoResult::failed(bytes, StreamError::Cancelled));
                }
                // Reset the length before the buffer is reused for the next read.
                buffer.clear();
            }
            Err(error) => {
                // Read failure: record it, try to deliver it, and finish.
                let error = io_error(error);
                finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                let _ = send_read_item(
                    &items,
                    DemandResponse::Error(error.clone()),
                    &mut cancel,
                    &waker,
                )
                .await;
                return Ok(IoResult::failed(bytes, error));
            }
        }
    }
}
1173
/// Linux build of the uring write sink.
///
/// Each run queues a write task on the uring runtime thread and spawns a
/// feeder that forwards the input to it. If the job cannot be queued, the run
/// is executed through the Tokio file write path instead.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
fn uring_file_write_sink_or_fallback(path: PathBuf) -> TokioByteSink {
    let path = Arc::new(path);
    Sink::from_runner(move |input, materializer| {
        let path = Arc::clone(&path);
        let (command_sender, command_receiver) = mpsc::channel(1);
        let (cancel_sender, cancel_receiver) = watch::channel(false);
        let (done_sender, done_receiver) = std_mpsc::sync_channel(1);
        let uring_path = Arc::clone(&path);

        let result = spawn_uring_job(Box::new(move || {
            Box::pin(async move {
                // A panic in the task is surfaced as an abrupt termination.
                let task_result = AssertUnwindSafe(run_uring_write_task(
                    uring_path,
                    command_receiver,
                    cancel_receiver,
                ))
                .catch_unwind()
                .await
                .unwrap_or(Err(StreamError::AbruptTermination));
                let _ = done_sender.send(task_result);
            })
        }));

        // The uring runtime is unavailable: run this materialization with Tokio file I/O.
        if result.is_err() {
            let fallback_path = Arc::clone(&path);
            return run_async_write_sink::<tokio::fs::File, _, _>(
                input,
                materializer,
                Arc::new(move || tokio_file_write_open(Arc::clone(&fallback_path))),
            );
        }

        Ok(materializer.spawn_stream(move |cancelled| {
            // If the feeder unwinds, the armed guard tells the writer task to stop.
            let mut guard = TokioCancelGuard::new(cancel_sender.clone());
            let result = feed_async_writer(
                input,
                command_sender,
                done_receiver,
                cancelled,
                cancel_sender,
            );
            guard.disarm();
            result
        }))
    })
}
1221
/// Creates the file at `path` with tokio-uring and executes write commands
/// until a `Finish` command, a failure, or cancellation.
///
/// The returned `IoResult` carries the number of bytes written. On `Finish`
/// its status is the upstream status combined with the close status. If the
/// command channel closes without a `Finish`, the file is closed best-effort
/// and `StreamError::Cancelled` is reported.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
async fn run_uring_write_task(
    path: Arc<PathBuf>,
    mut commands: mpsc::Receiver<WriteCommand>,
    mut cancel: watch::Receiver<bool>,
) -> StreamResult<IoResult> {
    let mut bytes = 0_u64;

    let created = tokio::select! {
        created = tokio_uring::fs::File::create(path.as_ref()) => created,
        _ = cancel.changed() => return Ok(IoResult::failed(bytes, StreamError::Cancelled)),
    };
    let file = match created {
        Ok(file) => file,
        Err(error) => return Ok(IoResult::failed(bytes, io_error(error))),
    };

    // File position of the next positional write.
    let mut offset = 0_u64;
    loop {
        let next = tokio::select! {
            command = commands.recv() => command,
            _ = cancel.changed() => return Ok(IoResult::failed(bytes, StreamError::Cancelled)),
        };

        let Some(command) = next else {
            // The feeder vanished without sending `Finish`.
            let _ = file.close().await;
            return Ok(IoResult::failed(bytes, StreamError::Cancelled));
        };

        match command {
            WriteCommand::Chunk(chunk) => {
                let written =
                    write_uring_chunk(&file, chunk, &mut offset, &mut cancel, &mut bytes).await;
                if let Err(error) = written {
                    return Ok(IoResult::failed(bytes, error));
                }
            }
            WriteCommand::Finish(upstream_status) => {
                let close_status = tokio::select! {
                    result = file.close() => result.map_err(io_error),
                    _ = cancel.changed() => Err(StreamError::Cancelled),
                };
                let status = upstream_status.and(close_status);
                return Ok(IoResult { bytes, status });
            }
        }
    }
}
1278
/// Writes all of `chunk` to `file` starting at `*offset`, resubmitting after
/// short writes until every byte is written.
///
/// `offset` and `bytes` are advanced by the size of each completed write.
/// Returns `StreamError::Cancelled` when `cancel` fires while a write is
/// pending, the zero-write error when a write reports zero bytes, and the
/// converted I/O error when a write fails.
#[cfg(all(feature = "io-uring-file", target_os = "linux"))]
async fn write_uring_chunk(
    file: &tokio_uring::fs::File,
    chunk: Vec<u8>,
    offset: &mut u64,
    cancel: &mut watch::Receiver<bool>,
    bytes: &mut u64,
) -> StreamResult<()> {
    use tokio_uring::buf::BoundedBuf;

    let total = chunk.len();
    if total == 0 {
        return Ok(());
    }

    let mut pending = chunk;
    let mut done = 0_usize;
    while done < total {
        // The buffer is handed to the write operation and returned with the
        // result, so it is re-sliced from the unwritten tail on every pass.
        let window = pending.slice(done..total);
        let (outcome, window) = tokio::select! {
            finished = file.write_at(window, *offset).submit() => finished,
            _ = cancel.changed() => {
                return Err(StreamError::Cancelled);
            }
        };
        pending = window.into_inner();

        let wrote = outcome.map_err(io_error)?;
        if wrote == 0 {
            return Err(write_zero_error());
        }
        done += wrote;
        *offset += wrote as u64;
        *bytes += wrote as u64;
    }

    Ok(())
}
1320
/// Local and remote socket addresses of a TCP connection.
///
/// Implements `Hash` alongside the equality traits so a connection can be
/// stored in a `HashSet` or used as a `HashMap` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TcpConnection {
    /// Address of the local end of the connection.
    pub local_addr: SocketAddr,
    /// Address of the remote peer.
    pub remote_addr: SocketAddr,
}

impl TcpConnection {
    /// Returns the address of the local end of the connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Returns the address of the remote peer.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.remote_addr
    }
}
1338
/// Address a TCP listener is bound to.
///
/// Implements `Hash` alongside the equality traits so a binding can be
/// stored in a `HashSet` or used as a `HashMap` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TcpBinding {
    /// Local address of the listener.
    pub local_addr: SocketAddr,
}

impl TcpBinding {
    /// Returns the local address of the listener.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }
}
1350
1351pub struct TcpIncomingConnection {
1357 connection: TcpConnection,
1358 source: TokioByteSource,
1359 sink: TokioByteSink,
1360}
1361
1362impl TcpIncomingConnection {
1363 #[must_use]
1364 pub fn local_addr(&self) -> SocketAddr {
1365 self.connection.local_addr
1366 }
1367
1368 #[must_use]
1369 pub fn remote_addr(&self) -> SocketAddr {
1370 self.connection.remote_addr
1371 }
1372
1373 #[must_use]
1374 pub fn connection(&self) -> TcpConnection {
1375 self.connection
1376 }
1377
1378 #[must_use]
1379 pub fn into_parts(self) -> (TokioByteSource, TokioByteSink) {
1380 (self.source, self.sink)
1381 }
1382
1383 #[must_use]
1384 pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
1385 Flow::from_sink_and_source_coupled(self.sink, self.source)
1386 .map_materialized_value(|_| NotUsed)
1387 }
1388}
1389
/// Namespace for the Tokio-backed TCP stream constructors.
pub struct TokioTcp;

impl TokioTcp {
    /// Returns a flow that connects to `addr` with `TcpStream::connect` and
    /// exchanges byte chunks over that connection.
    ///
    /// The connect call lives in the future handed to `Flow::future_flow`; a
    /// connect failure is converted with `io_error`. `chunk_size` is passed to
    /// `tcp_flow_from_stream` for the read side.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn outgoing_connection<A>(
        addr: A,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TcpConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Flow::future_flow(move || {
            // Cloned so the factory can be invoked more than once.
            let addr = addr.clone();
            async move {
                let stream = TcpStream::connect(addr).await.map_err(io_error)?;
                Ok(tcp_flow_from_stream(stream, chunk_size))
            }
        })
    }

    /// Same as [`TokioTcp::outgoing_connection`] with `DEFAULT_CHUNK_SIZE`.
    #[must_use]
    pub fn outgoing_connection_default<A>(
        addr: A,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TcpConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::outgoing_connection(addr, DEFAULT_CHUNK_SIZE)
    }

    /// Returns a source of connections accepted on `addr`.
    ///
    /// On materialization a task running `run_tcp_bind_task` is spawned on
    /// `stream_tokio_runtime()`. The materialized `StreamCompletion` is fed
    /// from that task's binding channel, and the emitted stream requests
    /// connections from the task through the demand channel.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn bind<A>(
        addr: A,
        chunk_size: usize,
    ) -> Source<TcpIncomingConnection, StreamCompletion<TcpBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Source::from_materialized_factory(move |_materializer| {
            let (demand_sender, demand_receiver) = mpsc::channel(1);
            let (cancel_sender, cancel_receiver) = watch::channel(false);
            let (binding_sender, binding_receiver) = oneshot::channel();
            // Terminal state shared between the accept task and the stream.
            let terminal = Arc::new(Mutex::new(None));
            let terminal_for_task = Arc::clone(&terminal);
            let addr = addr.clone();

            crate::stream::stream_tokio_runtime().spawn(async move {
                let result = AssertUnwindSafe(run_tcp_bind_task(
                    addr,
                    chunk_size,
                    demand_receiver,
                    cancel_receiver,
                    binding_sender,
                    Arc::clone(&terminal_for_task),
                ))
                .catch_unwind()
                .await;
                // A panic in the accept task is recorded as an abrupt
                // termination rather than left without a terminal state.
                if result.is_err() {
                    finish_terminal(
                        &terminal_for_task,
                        DemandTerminal::Error(StreamError::AbruptTermination),
                    );
                }
            });

            Ok((
                Box::new(DemandSourceStream {
                    demands: demand_sender,
                    cancel: cancel_sender,
                    terminal,
                    done: false,
                }) as BoxStream<TcpIncomingConnection>,
                StreamCompletion::from_receiver(binding_receiver, None),
            ))
        })
    }

    /// Same as [`TokioTcp::bind`] with `DEFAULT_CHUNK_SIZE`.
    #[must_use]
    pub fn bind_default<A>(addr: A) -> Source<TcpIncomingConnection, StreamCompletion<TcpBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::bind(addr, DEFAULT_CHUNK_SIZE)
    }
}
1488
1489fn tcp_flow_from_stream(
1490 stream: TcpStream,
1491 chunk_size: usize,
1492) -> Flow<Vec<u8>, Vec<u8>, TcpConnection> {
1493 let connection = TcpConnection {
1494 local_addr: stream
1495 .local_addr()
1496 .expect("connected TCP stream has local address"),
1497 remote_addr: stream
1498 .peer_addr()
1499 .expect("connected TCP stream has peer address"),
1500 };
1501 let (read_half, write_half) = stream.into_split();
1502 let source = single_use_async_read_source(read_half, chunk_size);
1503 let sink = single_use_async_write_sink(write_half);
1504 Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| connection)
1511}
1512
1513fn single_use_async_read_source<R>(reader: R, chunk_size: usize) -> TokioByteSource
1514where
1515 R: AsyncRead + Unpin + Send + 'static,
1516{
1517 let reader = Arc::new(Mutex::new(Some(reader)));
1518 Source::from_materialized_factory(move |_materializer| {
1519 let reader = Arc::clone(&reader);
1520 Ok(async_read_source(
1521 move || async move {
1522 reader
1523 .lock()
1524 .expect("single-use async reader poisoned")
1525 .take()
1526 .ok_or_else(|| std::io::Error::other("async reader already materialized"))
1527 },
1528 chunk_size,
1529 chunk_size,
1530 TCP_READ_AHEAD_CHUNKS,
1531 ))
1532 })
1533}
1534
1535fn single_use_async_write_sink<W>(writer: W) -> TokioByteSink
1536where
1537 W: AsyncWrite + Unpin + Send + 'static,
1538{
1539 let writer = Arc::new(Mutex::new(Some(writer)));
1540 async_write_sink(move || {
1541 let writer = Arc::clone(&writer);
1542 async move {
1543 writer
1544 .lock()
1545 .expect("single-use async writer poisoned")
1546 .take()
1547 .ok_or_else(|| std::io::Error::other("async writer already materialized"))
1548 }
1549 })
1550}
1551
/// Accept loop behind `TokioTcp::bind`.
///
/// Binds a listener on `addr`, reports the bound address (or the failure)
/// through `binding_sender`, then answers each demand received on `demands`
/// with one accepted connection. The task ends when `next_demand` yields
/// `None`, when `cancel` changes during an accept, when accept fails with a
/// non-transient error, or when a reply cannot be delivered; every exit after
/// a failure or cancellation records a terminal state via `finish_terminal`.
async fn run_tcp_bind_task<A>(
    addr: A,
    chunk_size: usize,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TcpIncomingConnection>>>,
    mut cancel: watch::Receiver<bool>,
    binding_sender: oneshot::Sender<StreamResult<TcpBinding>>,
    terminal: Arc<Mutex<Option<DemandTerminal>>>,
) where
    A: ToSocketAddrs + Send + 'static,
{
    let listener = match TcpListener::bind(addr).await {
        Ok(listener) => listener,
        Err(error) => {
            // A bind failure is reported to both the stream (terminal state)
            // and the materialized binding completion.
            let error = io_error(error);
            finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
            let _ = binding_sender.send(Err(error));
            return;
        }
    };
    let local_addr = match listener.local_addr() {
        Ok(local_addr) => local_addr,
        Err(error) => {
            let error = io_error(error);
            finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
            let _ = binding_sender.send(Err(error));
            return;
        }
    };
    // The binding receiver may already be dropped; that is not an error here.
    let _ = binding_sender.send(Ok(TcpBinding { local_addr }));

    loop {
        // One accepted connection is produced per received demand.
        let Some(reply) = next_demand(&mut demands, &mut cancel).await else {
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return;
        };

        let (stream, remote_addr) = loop {
            let accepted = tokio::select! {
                accepted = listener.accept() => accepted,
                changed = cancel.changed() => {
                    let _ = changed;
                    finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
                    return;
                }
            };

            match accepted {
                Ok(accepted) => break accepted,
                // Transient failures are retried without ending the listener.
                Err(error) if is_transient_accept_error(&error) => continue,
                Err(error) => {
                    let error = io_error(error);
                    finish_terminal(&terminal, DemandTerminal::Error(error.clone()));
                    let _ = reply.send(DemandResponse::Error(error));
                    return;
                }
            }
        };

        let incoming = tcp_incoming_connection(stream, remote_addr, local_addr, chunk_size);
        // A failed reply means the requester is gone; treat it as cancellation.
        if reply.send(DemandResponse::Item(incoming)).is_err() {
            finish_terminal(&terminal, DemandTerminal::Error(StreamError::Cancelled));
            return;
        }
    }
}
1620
1621fn is_transient_accept_error(error: &std::io::Error) -> bool {
1622 matches!(
1623 error.kind(),
1624 std::io::ErrorKind::Interrupted
1625 | std::io::ErrorKind::ConnectionAborted
1626 | std::io::ErrorKind::ConnectionReset
1627 ) || error.raw_os_error().is_some_and(is_transient_accept_errno)
1628}
1629
/// Reports whether a raw OS error code from `accept` is transient on Linux.
///
/// NOTE(review): the numeric values match the common Linux errno table;
/// confirm they hold for every targeted architecture.
#[cfg(target_os = "linux")]
fn is_transient_accept_errno(code: i32) -> bool {
    const EINTR: i32 = 4;
    const ECONNABORTED: i32 = 103;
    const ECONNRESET: i32 = 104;

    [EINTR, ECONNABORTED, ECONNRESET].contains(&code)
}

/// On non-Linux targets no raw OS error code is classified as transient.
#[cfg(not(target_os = "linux"))]
fn is_transient_accept_errno(_code: i32) -> bool {
    false
}
1639
1640fn tcp_incoming_connection(
1641 stream: TcpStream,
1642 remote_addr: SocketAddr,
1643 local_addr: SocketAddr,
1644 chunk_size: usize,
1645) -> TcpIncomingConnection {
1646 let connection = TcpConnection {
1647 local_addr,
1648 remote_addr,
1649 };
1650 let (read_half, write_half) = stream.into_split();
1651 let source = single_use_async_read_source(read_half, chunk_size);
1652 let sink = single_use_async_write_sink(write_half);
1653 TcpIncomingConnection {
1654 connection,
1655 source,
1656 sink,
1657 }
1658}
1659
1660#[cfg(test)]
1661mod tests {
1662 use super::*;
1663 use crate::{Framing, Keep, Sink, Source};
1664 use std::pin::Pin;
1665 use std::sync::atomic::{AtomicBool as StdAtomicBool, Ordering as StdOrdering};
1666 use std::task::{Context, Poll};
1667 use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
1668
1669 fn unique_temp_path(name: &str) -> PathBuf {
1670 let nanos = SystemTime::now()
1671 .duration_since(UNIX_EPOCH)
1672 .expect("clock after epoch")
1673 .as_nanos();
1674 std::env::temp_dir().join(format!(
1675 "datum-wp12b-{name}-{}-{nanos}.bin",
1676 std::process::id()
1677 ))
1678 }
1679
1680 fn wait_until(timeout: Duration, condition: impl Fn() -> bool) -> bool {
1681 let deadline = Instant::now() + timeout;
1682 while Instant::now() < deadline {
1683 if condition() {
1684 return true;
1685 }
1686 thread::sleep(Duration::from_millis(5));
1687 }
1688 condition()
1689 }
1690
1691 struct PendingWriter {
1692 polled: Arc<StdAtomicBool>,
1693 dropped: Arc<StdAtomicBool>,
1694 }
1695
1696 impl AsyncWrite for PendingWriter {
1697 fn poll_write(
1698 self: Pin<&mut Self>,
1699 _cx: &mut Context<'_>,
1700 _buf: &[u8],
1701 ) -> Poll<std::io::Result<usize>> {
1702 self.polled.store(true, StdOrdering::SeqCst);
1703 Poll::Pending
1704 }
1705
1706 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1707 self.polled.store(true, StdOrdering::SeqCst);
1708 Poll::Pending
1709 }
1710
1711 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1712 self.polled.store(true, StdOrdering::SeqCst);
1713 Poll::Pending
1714 }
1715 }
1716
1717 impl Drop for PendingWriter {
1718 fn drop(&mut self) {
1719 self.dropped.store(true, StdOrdering::SeqCst);
1720 }
1721 }
1722
    /// Writes two chunks through the Tokio file sink, reads them back with a
    /// two-byte chunk size, and checks the byte counts and statuses reported
    /// by both completions.
    #[test]
    fn tokio_file_io_round_trips_bytes_and_reports_counts() {
        let path = unique_temp_path("roundtrip");
        let write_completion = Source::from_iter([b"ab".to_vec(), b"cd".to_vec()])
            .run_with(TokioFileIO::to_path(path.clone()))
            .expect("tokio file sink materializes");
        let write_result = write_completion.wait().expect("tokio file write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));

        let (read_completion, collected) = TokioFileIO::from_path(path.clone(), 2)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");
        assert_eq!(
            collected.wait().expect("collect completes"),
            vec![b"ab".to_vec(), b"cd".to_vec()]
        );
        let read_result = read_completion.wait().expect("read completion available");
        assert_eq!(read_result.bytes(), 4);
        assert_eq!(read_result.status(), Ok(()));

        std::fs::remove_file(path).expect("remove roundtrip file");
    }
1747
    /// Reading a path that was never created must still materialize, fail the
    /// downstream collect, and report zero bytes with a `Failed` status.
    #[test]
    fn tokio_file_source_surfaces_open_failure() {
        let missing = unique_temp_path("missing");
        let (read_completion, collected) = TokioFileIO::from_path(missing, 4)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes despite open failure");
        let stream_error = collected.wait().expect_err("collect fails");
        assert!(matches!(stream_error, StreamError::Failed(_)));
        let read_result = read_completion.wait().expect("io result available");
        assert_eq!(read_result.bytes(), 0);
        assert!(matches!(read_result.status(), Err(StreamError::Failed(_))));
    }
1761
    /// Reads a newline-delimited file in five-byte chunks and checks that the
    /// delimiter framing stage reassembles the three original lines.
    #[test]
    fn tokio_file_source_composes_with_framing_and_sink() {
        let path = unique_temp_path("framing");
        std::fs::write(&path, b"alpha\nbeta\ngamma\n").expect("write framed seed file");

        let frames = TokioFileIO::from_path(path.clone(), 5)
            .via(Framing::delimiter(b"\n".to_vec(), 64, true))
            .run_with(Sink::collect())
            .expect("framed file stream materializes")
            .wait()
            .expect("framed file stream completes");

        assert_eq!(
            frames,
            vec![b"alpha".to_vec(), b"beta".to_vec(), b"gamma".to_vec()]
        );
        std::fs::remove_file(path).expect("remove framed file");
    }
1780
    /// Reads a file slightly larger than `FILE_INTERNAL_READ_SIZE` and checks
    /// that every emitted chunk except the last has exactly the requested
    /// size, that the tail chunk carries the remainder, and that the data and
    /// reported byte count match the seed file.
    #[test]
    fn tokio_file_source_preserves_requested_chunk_boundaries() {
        let path = unique_temp_path("chunk-boundaries");
        let chunk_size = 8192;
        let tail_size = 13;
        // One internal read plus a short tail, so the data spans two reads.
        let data_len = FILE_INTERNAL_READ_SIZE + tail_size;
        let data: Vec<u8> = (0..data_len).map(|index| (index % 251) as u8).collect();
        std::fs::write(&path, &data).expect("write chunk boundary seed file");

        let (read_completion, chunks) = TokioFileIO::from_path(path.clone(), chunk_size)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("tokio file source materializes");
        let chunks = chunks.wait().expect("chunk boundary stream completes");

        assert!(chunks.len() > 1);
        for chunk in &chunks[..chunks.len() - 1] {
            assert_eq!(chunk.len(), chunk_size);
        }
        assert_eq!(chunks.last().expect("tail chunk exists").len(), tail_size);
        let reassembled: Vec<u8> = chunks.into_iter().flatten().collect();
        assert_eq!(reassembled, data);
        assert_eq!(
            read_completion
                .wait()
                .expect("read completion available")
                .bytes(),
            data_len as u64
        );

        std::fs::remove_file(path).expect("remove chunk boundary file");
    }
1813
    /// io_uring counterpart of the Tokio round-trip test: writes two chunks
    /// with `UringFileIO`, reads them back with a two-byte chunk size, and
    /// checks byte counts and statuses on both completions.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_io_round_trips_bytes_and_reports_counts() {
        let path = unique_temp_path("uring-roundtrip");
        let write_completion = Source::from_iter([b"ab".to_vec(), b"cd".to_vec()])
            .run_with(UringFileIO::to_path(path.clone()))
            .expect("uring file sink materializes");
        let write_result = write_completion.wait().expect("uring file write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));

        let (read_completion, collected) = UringFileIO::from_path(path.clone(), 2)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("uring file source materializes");
        assert_eq!(
            collected.wait().expect("collect completes"),
            vec![b"ab".to_vec(), b"cd".to_vec()]
        );
        let read_result = read_completion.wait().expect("read completion available");
        assert_eq!(read_result.bytes(), 4);
        assert_eq!(read_result.status(), Ok(()));

        std::fs::remove_file(path).expect("remove uring roundtrip file");
    }
1839
    /// Reading a path that was never created through `UringFileIO` must still
    /// materialize, fail the downstream collect, and report zero bytes with a
    /// `Failed` status.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_source_surfaces_open_failure() {
        let missing = unique_temp_path("uring-missing");
        let (read_completion, collected) = UringFileIO::from_path(missing, 4)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("uring file source materializes despite open failure");
        let stream_error = collected.wait().expect_err("collect fails");
        assert!(matches!(stream_error, StreamError::Failed(_)));
        let read_result = read_completion.wait().expect("io result available");
        assert_eq!(read_result.bytes(), 0);
        assert!(matches!(read_result.status(), Err(StreamError::Failed(_))));
    }
1854
    /// Points the io_uring sink at a directory, so the file cannot be created,
    /// and checks that the sink still materializes and reports zero bytes
    /// with a `Failed` status.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_sink_surfaces_open_failure() {
        let path = unique_temp_path("uring-sink-dir");
        std::fs::create_dir(&path).expect("create sink failure directory");

        let completion = Source::single(b"blocked".to_vec())
            .run_with(UringFileIO::to_path(path.clone()))
            .expect("uring file sink materializes despite open failure");
        let result = completion.wait().expect("uring sink result available");
        assert_eq!(result.bytes(), 0);
        assert!(matches!(result.status(), Err(StreamError::Failed(_))));

        std::fs::remove_dir(path).expect("remove sink failure directory");
    }
1870
    /// io_uring counterpart of the chunk-boundary test: every emitted chunk
    /// except the last must have exactly the requested size, the tail chunk
    /// must carry the remainder, and the data and byte count must match the
    /// seed file.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_source_preserves_requested_chunk_boundaries() {
        let path = unique_temp_path("uring-chunk-boundaries");
        let chunk_size = 8192;
        let tail_size = 13;
        // One internal read plus a short tail, so the data spans two reads.
        let data_len = FILE_INTERNAL_READ_SIZE + tail_size;
        let data: Vec<u8> = (0..data_len).map(|index| (index % 251) as u8).collect();
        std::fs::write(&path, &data).expect("write uring chunk boundary seed file");

        let (read_completion, chunks) = UringFileIO::from_path(path.clone(), chunk_size)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .expect("uring file source materializes");
        let chunks = chunks.wait().expect("chunk boundary stream completes");

        assert!(chunks.len() > 1);
        for chunk in &chunks[..chunks.len() - 1] {
            assert_eq!(chunk.len(), chunk_size);
        }
        assert_eq!(chunks.last().expect("tail chunk exists").len(), tail_size);
        let reassembled: Vec<u8> = chunks.into_iter().flatten().collect();
        assert_eq!(reassembled, data);
        assert_eq!(
            read_completion
                .wait()
                .expect("read completion available")
                .bytes(),
            data_len as u64
        );

        std::fs::remove_file(path).expect("remove uring chunk boundary file");
    }
1904
    /// Takes a single chunk from a file four internal reads long and checks
    /// that the read completion reports some bytes together with a
    /// `Cancelled` status once downstream stops early.
    #[cfg(all(feature = "io-uring-file", target_os = "linux"))]
    #[test]
    fn uring_file_source_cancellation_reports_cancelled() {
        let path = unique_temp_path("uring-cancel");
        let data_len = FILE_INTERNAL_READ_SIZE * 4;
        std::fs::write(&path, vec![b'z'; data_len]).expect("write uring cancel seed file");

        // `take(1)` ends the stream after the first chunk, long before the
        // file is exhausted.
        let (completion, ignored) = UringFileIO::from_path(path.clone(), 8192)
            .take(1)
            .to_mat(Sink::ignore(), Keep::both)
            .run()
            .expect("uring cancellation source materializes");
        ignored.wait().expect("downstream ignore completes");
        let result = completion.wait().expect("uring read completion available");

        assert!(result.bytes() > 0);
        assert!(matches!(result.status(), Err(StreamError::Cancelled)));

        std::fs::remove_file(path).expect("remove uring cancel file");
    }
1925
    /// Runs a sink over a writer that never completes, waits until the writer
    /// has been polled, drops the completion handle, and checks that the
    /// writer is dropped within one second.
    #[test]
    fn tokio_sink_cancellation_unblocks_pending_writer_completion_wait() {
        let polled = Arc::new(StdAtomicBool::new(false));
        let dropped = Arc::new(StdAtomicBool::new(false));
        let completion = Source::single(b"blocked".to_vec())
            .run_with(async_write_sink({
                let polled = Arc::clone(&polled);
                let dropped = Arc::clone(&dropped);
                move || {
                    let polled = Arc::clone(&polled);
                    let dropped = Arc::clone(&dropped);
                    async move { Ok(PendingWriter { polled, dropped }) }
                }
            }))
            .expect("pending writer sink materializes");

        // The write must be in flight before the completion is dropped.
        assert!(wait_until(Duration::from_secs(1), || {
            polled.load(StdOrdering::SeqCst)
        }));
        drop(completion);
        assert!(wait_until(Duration::from_secs(1), || {
            dropped.load(StdOrdering::SeqCst)
        }));
    }
1950
1951 #[test]
1952 fn tokio_tcp_accept_error_classifier_retries_only_connection_races() {
1953 assert!(is_transient_accept_error(&std::io::Error::new(
1954 std::io::ErrorKind::Interrupted,
1955 "interrupted"
1956 )));
1957 assert!(is_transient_accept_error(&std::io::Error::new(
1958 std::io::ErrorKind::ConnectionAborted,
1959 "aborted before accept"
1960 )));
1961 assert!(is_transient_accept_error(&std::io::Error::new(
1962 std::io::ErrorKind::ConnectionReset,
1963 "reset before accept"
1964 )));
1965 assert!(!is_transient_accept_error(&std::io::Error::other(
1966 "fd pressure"
1967 )));
1968 }
1969
    /// Runs a hand-written source that emits four 8192-byte chunks and then
    /// parks for up to 50 ms between items, drops the completion handle from
    /// another thread after 20 ms, and requires the source to observe the
    /// stream's cancellation flag within three seconds.
    #[test]
    fn tokio_source_cancellation_observed_promptly_under_wake_on_send() {
        let cancelled_detected = Arc::new(AtomicBool::new(false));
        let detected = Arc::clone(&cancelled_detected);

        let source = Source::from_materialized_factory(move |_materializer| {
            let d = Arc::clone(&detected);
            let mut i = 0_u64;
            let stream: BoxStream<Vec<u8>> = Box::new(std::iter::from_fn(move || {
                // Reads the cancellation flag of the stream currently running
                // this iterator, if there is one.
                let is_cancelled = || {
                    crate::stream::current_stream_cancelled()
                        .as_ref()
                        .is_some_and(|c| c.load(Ordering::SeqCst))
                };
                if is_cancelled() {
                    d.store(true, Ordering::SeqCst);
                    return Some(Err(StreamError::Cancelled));
                }
                // The first four items are produced without waiting.
                if i < 4 {
                    i += 1;
                    return Some(Ok(vec![0_u8; 8192]));
                }
                // Afterwards the producer parks, then re-checks the flag
                // before emitting another item.
                std::thread::park_timeout(Duration::from_millis(50));
                if is_cancelled() {
                    d.store(true, Ordering::SeqCst);
                    return Some(Err(StreamError::Cancelled));
                }
                Some(Ok(vec![0_u8]))
            }));
            Ok((stream, NotUsed))
        });

        let completion = source
            .run_with(Sink::ignore())
            .expect("cancellation source materializes");

        // Dropping the completion handle is what triggers cancellation here.
        let cancel_thread = thread::spawn(move || {
            thread::sleep(Duration::from_millis(20));
            drop(completion);
        });

        let deadline = Instant::now() + Duration::from_secs(3);
        loop {
            if cancelled_detected.load(Ordering::SeqCst) {
                break;
            }
            assert!(
                Instant::now() < deadline,
                "cancellation not observed within 3 s; the park‑wake path may be broken"
            );
            thread::park_timeout(Duration::from_millis(10));
        }
        cancel_thread.join().expect("cancellation thread joins");
    }
2043
    /// End-to-end echo over loopback: binds on an OS-assigned port, connects a
    /// client that sends `ping`, reads the request on the accepted connection,
    /// writes it back, and checks that the client receives the same bytes.
    #[test]
    fn tokio_tcp_bind_and_outgoing_connection_echo_round_trip() {
        let (binding_completion, incoming_completion) = TokioTcp::bind("127.0.0.1:0", 1024)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("tcp bind source materializes");
        // The bound address is needed before the client can connect.
        let binding = binding_completion.wait().expect("tcp binding succeeds");

        let client_completion = Source::single(b"ping".to_vec())
            .via(TokioTcp::outgoing_connection(binding.local_addr(), 1024))
            .run_with(Sink::head())
            .expect("client stream materializes");

        let incoming = incoming_completion
            .wait()
            .expect("incoming connection accepted");
        let (incoming_source, incoming_sink) = incoming.into_parts();
        let server_read = incoming_source
            .run_with(Sink::head())
            .expect("server read materializes")
            .wait()
            .expect("server reads request");
        assert_eq!(server_read, b"ping".to_vec());

        let server_write = Source::single(server_read)
            .run_with(incoming_sink)
            .expect("server write materializes");
        let write_result = server_write.wait().expect("server write completes");
        assert_eq!(write_result.bytes(), 4);
        assert_eq!(write_result.status(), Ok(()));

        assert_eq!(
            client_completion.wait().expect("client receives echo"),
            b"ping".to_vec()
        );
    }
2080}