Skip to main content

futures_copy/
copy_buf_bidi.rs

//! Functionality to copy bidirectionally between two streams
//! that implement `AsyncBufRead` and `AsyncWrite`.
3
4use std::{
5    io,
6    pin::Pin,
7    task::{Context, Poll, ready},
8};
9
10use futures::{AsyncBufRead, AsyncWrite};
11use pin_project::pin_project;
12
13use crate::{
14    arc_io_result::{ArcIoResult, wrap_error},
15    copy_buf::poll_copy_r_to_w,
16    eof::EofStrategy,
17    fuse_buf_reader::FuseBufReader,
18};
19
20/// Return a future to copies bytes from `stream_a` to `stream_b`,
21/// and from `stream_b` to `stream_a`.
22///
23/// The future makes sure that
24/// if a stream pauses (returns Pending),
25/// all as-yet-received bytes are still flushed to the other stream.
26///
27/// If an EOF is read from `stream_a`,
28/// the future uses `on_a_eof` to report the EOF to `stream_b`.
29/// Similarly, if an EOF is read from  `stream_b`,
30/// the future uses `on_b_eof` to report the EOF to `stream_a`.
31///
32/// The future will continue running until either an error has occurred
33/// (in which case it yields an error),
34/// or until both streams have returned an EOF as readers
35/// and have both been flushed as writers
36/// (in which case it yields a tuple of the number of bytes copied from a to b,
37/// and the number of bytes copied from b to a.)
38///
39/// # Limitations
40///
41/// See the crate-level documentation for
42/// [discussion of this function's limitations](crate#Limitations).
43pub fn copy_buf_bidirectional<A, B, AE, BE>(
44    stream_a: A,
45    stream_b: B,
46    on_a_eof: AE,
47    on_b_eof: BE,
48) -> CopyBufBidirectional<A, B, AE, BE>
49where
50    A: AsyncBufRead + AsyncWrite,
51    B: AsyncBufRead + AsyncWrite,
52    AE: EofStrategy<B>,
53    BE: EofStrategy<A>,
54{
55    CopyBufBidirectional {
56        stream_a: FuseBufReader::new(stream_a),
57        stream_b: FuseBufReader::new(stream_b),
58        on_a_eof,
59        on_b_eof,
60        copied_a_to_b: 0,
61        copied_b_to_a: 0,
62        a_to_b_status: DirectionStatus::Copying,
63        b_to_a_status: DirectionStatus::Copying,
64    }
65}
66
67/// A future returned by [`copy_buf_bidirectional`].
68//
69// Note to the reader: You might think it's a good idea to have two separate CopyBuf futures here.
70// That won't work, though, since each one would need to own both `stream_a` and `stream_b`.
71// We could use `split` to share the streams, but that would introduce needless locking overhead.
72//
73// Instead, we implement the shared functionality via poll_copy_r_to_w.
74#[derive(Debug)]
75#[pin_project]
76#[must_use = "futures do nothing unless you `.await` or poll them"]
77pub struct CopyBufBidirectional<A, B, AE, BE> {
78    /// The first stream.
79    #[pin]
80    stream_a: FuseBufReader<A>,
81
82    /// The second stream.
83    #[pin]
84    stream_b: FuseBufReader<B>,
85
86    /// An [`EofStrategy`] to use when `stream_a` reaches EOF.
87    #[pin]
88    on_a_eof: AE,
89
90    /// An [`EofStrategy`] to use when `stream_b` reaches EOF.
91    #[pin]
92    on_b_eof: BE,
93
94    /// The number of bytes from `a` written onto `b` so far.
95    copied_a_to_b: u64,
96    /// The number of bytes from `b` written onto `a` so far.
97    copied_b_to_a: u64,
98
99    /// The current status of copying from `a` to `b`.
100    a_to_b_status: DirectionStatus,
101
102    /// The current status of copying from `b` to `a`.
103    b_to_a_status: DirectionStatus,
104}
105
106impl<A, B, AE, BE> CopyBufBidirectional<A, B, AE, BE> {
107    /// Consume this CopyBufBirectional future, and return the underlying streams.
108    pub fn into_inner(self) -> (A, B) {
109        (self.stream_a.into_inner(), self.stream_b.into_inner())
110    }
111}
112
113impl<A, B, AE, BE> Future for CopyBufBidirectional<A, B, AE, BE>
114where
115    A: AsyncBufRead + AsyncWrite,
116    B: AsyncBufRead + AsyncWrite,
117    AE: EofStrategy<B>,
118    BE: EofStrategy<A>,
119{
120    type Output = io::Result<(u64, u64)>;
121
122    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123        use DirectionStatus::*;
124
125        let mut this = self.project();
126
127        if *this.a_to_b_status != DirectionStatus::Done {
128            let _ignore_completion = one_direction(
129                cx,
130                this.stream_a.as_mut(),
131                this.stream_b.as_mut(),
132                this.on_a_eof,
133                this.copied_a_to_b,
134                this.a_to_b_status,
135            )
136            .map_err(|e| wrap_error(&e))?;
137        }
138
139        if *this.b_to_a_status != DirectionStatus::Done {
140            let _ignore_completion = one_direction(
141                cx,
142                this.stream_b.as_mut(),
143                this.stream_a.as_mut(),
144                this.on_b_eof,
145                this.copied_b_to_a,
146                this.b_to_a_status,
147            )
148            .map_err(|e| wrap_error(&e))?;
149        }
150
151        if (*this.a_to_b_status, *this.b_to_a_status) == (Done, Done) {
152            Poll::Ready(Ok((*this.copied_a_to_b, *this.copied_b_to_a)))
153        } else {
154            Poll::Pending
155        }
156    }
157}
158
/// Where a single copying direction is in its lifecycle.
///
/// Within this module a direction only ever moves forward:
/// `Copying` → `SendingEof` → `Done`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DirectionStatus {
    /// Still moving bytes from the reader to the writer; no EOF has been read yet.
    Copying,

    /// The reader hit EOF, and an [`EofStrategy`] is relaying that EOF to the writer.
    SendingEof,

    /// The EOF has been delivered; this direction has no work left.
    Done,
}
171
172/// Try to make progress copying data in a single data, and propagating the EOF.
173fn one_direction<A, B, AE>(
174    cx: &mut Context<'_>,
175    r: Pin<&mut FuseBufReader<A>>,
176    mut w: Pin<&mut FuseBufReader<B>>,
177    eof_strategy: Pin<&mut AE>,
178    n_copied: &mut u64,
179    status: &mut DirectionStatus,
180) -> Poll<ArcIoResult<()>>
181where
182    A: AsyncBufRead,
183    B: AsyncWrite,
184    AE: EofStrategy<B>,
185{
186    use DirectionStatus::*;
187
188    if *status == Copying {
189        let () = ready!(poll_copy_r_to_w(cx, r, w.as_mut(), n_copied, false))?;
190        *status = SendingEof;
191    }
192
193    if *status == SendingEof {
194        let () = ready!(eof_strategy.poll_send_eof(cx, w.get_pin_mut()))?;
195        *status = Done;
196    }
197
198    assert_eq!(*status, Done);
199    Poll::Ready(Ok(()))
200}
201
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use crate::{eof, test::RWPair};

    use futures::{
        AsyncBufReadExt,
        io::{BufReader, BufWriter, Cursor},
    };
    use tor_rtcompat::SpawnExt as _;
    use tor_rtmock::{MockRuntime, io::stream_pair};

    /// Return a stream implemented with a pair of Vec-backed cursors.
    ///
    /// Reads come from a cursor over a copy of `init_data`.  Writes go to a
    /// second cursor over an initially empty `Vec`, which the tests recover with
    /// `.into_inner().1.into_inner()`.
    #[allow(clippy::type_complexity)]
    fn cursor_stream(init_data: &[u8]) -> BufReader<RWPair<Cursor<Vec<u8>>, Cursor<Vec<u8>>>> {
        BufReader::new(RWPair(
            Cursor::new(init_data.to_vec()),
            Cursor::new(Vec::new()),
        ))
    }

    /// Copy `data_1` and `data_2` in opposite directions between two
    /// cursor-backed streams.  Check the reported byte counts and the bytes
    /// that each stream received.
    async fn test_transfer_cursor(data_1: &[u8], data_2: &[u8]) {
        let mut s1 = cursor_stream(data_1);
        let mut s2 = cursor_stream(data_2);

        let (t1, t2) = copy_buf_bidirectional(&mut s1, &mut s2, eof::Close, eof::Close)
            .await
            .unwrap();
        assert_eq!(t1, data_1.len() as u64);
        assert_eq!(t2, data_2.len() as u64);
        let out1 = s1.into_inner().1.into_inner();
        let out2 = s2.into_inner().1.into_inner();
        assert_eq!(&out1[..], data_2);
        assert_eq!(&out2[..], data_1);
    }

    /// Like `test_transfer_cursor`, but relay the data through a mock-runtime
    /// stream pair using two separate `copy_buf_bidirectional` tasks:
    ///
    /// `s1 <--> task h1 <--> s2 <-Rt-> s3 <--> task h2 <--> s4`
    async fn test_transfer_streams(rt: &MockRuntime, data_1: &[u8], data_2: &[u8]) {
        let mut s1 = cursor_stream(data_1);
        let (s2, s3) = stream_pair();
        let mut s4 = cursor_stream(data_2);

        let h1 = rt
            .spawn_with_handle(async move {
                let r = copy_buf_bidirectional(&mut s1, BufReader::new(s2), eof::Close, eof::Close)
                    .await;
                (r, s1.into_inner().1.into_inner())
            })
            .unwrap();
        let h2 = rt
            .spawn_with_handle(async move {
                let r = copy_buf_bidirectional(BufReader::new(s3), &mut s4, eof::Close, eof::Close)
                    .await;
                (r, s4.into_inner().1.into_inner())
            })
            .unwrap();
        let (r1, buf1) = h1.await;
        let (r2, buf2) = h2.await;

        assert_eq!(r1.unwrap(), (data_1.len() as u64, data_2.len() as u64));
        assert_eq!(r2.unwrap(), (data_1.len() as u64, data_2.len() as u64));
        assert_eq!(&buf1, data_2);
        assert_eq!(&buf2, data_1);
    }

    /// Run both transfer tests for the given pair of payloads under the
    /// mock runtime's various configurations.
    fn test_transfer(data_1: &[u8], data_2: &[u8]) {
        MockRuntime::test_with_various(async |rt| {
            test_transfer_cursor(data_1, data_2).await;
            test_transfer_streams(&rt, data_1, data_2).await;
        });
    }

    /// Return a 1_234_567-byte buffer that repeatedly cycles through `1..=x`.
    fn big(x: u8) -> Vec<u8> {
        (1..=x).cycle().take(1_234_567).collect()
    }

    #[test]
    fn transfer_empty() {
        test_transfer(&[], &[]);
    }

    #[test]
    fn transfer_empty_small() {
        test_transfer(&[], b"hello world");
    }

    #[test]
    fn transfer_small() {
        test_transfer(b"hola mundo", b"hello world");
    }

    #[test]
    fn transfer_huge() {
        let big1 = big(79);
        let big2 = big(81);
        test_transfer(&big1, &big2);
    }

    #[test]
    fn interactive_protocol() {
        use futures::io::AsyncWriteExt as _;
        // Test our flush behavior by relaying traffic between a pair of communicators
        // that (after task 1 sends the opening number) only speak in reply to a message.
        // If the relay fails to flush, neither side ever sees the other's message.

        MockRuntime::test_with_various(async |rt| {
            let (s1, s2) = stream_pair();
            let (s3, s4) = stream_pair();

            // Using BufWriter here means that unless we propagate the flush correctly,
            // flushing won't happen soon enough to cause a reply.
            let mut s1 = BufReader::new(s1);
            let s2 = BufReader::new(BufWriter::with_capacity(1024, s2));
            let s3 = BufReader::new(BufWriter::with_capacity(1024, s3));
            let mut s4 = BufReader::new(s4);

            // That's a lot of streams!  Here's how they all connect:
            //
            // Task 1 <--> s1  <-Rt-> s2 <-> Task 2 <--> s3 <-Rt-> s4 <--> Task 3
            //
            // In other words, s1 and s2 are automatically connected under the hood by
            // the MockRuntime, as are s3 and s4.  Task 1 reads and writes from s1.
            // Task 2 tests copy_buf_bidirectional by relaying between s2 and s3.
            // And Task 3 reads and writes to s4.
            //
            // Thus task 1 and task 3 can only communicate with one another if
            // task 2 (and copy_buf_bidirectional) do their job.

            // Task 1:
            // Write a number starting with 1, then read numbers and write back 1 more.
            // Continue until you read a number >= 100.
            let h1 = rt
                .spawn_with_handle(async move {
                    let mut buf = String::new();
                    let mut num: u32 = 1;

                    loop {
                        s1.write_all(format!("{num}\n").as_bytes()).await?;
                        s1.flush().await?;

                        let written = num;

                        let n_bytes_read = s1.read_line(&mut buf).await?;
                        if n_bytes_read == 0 {
                            break;
                        }
                        num = buf.trim_ascii().parse().unwrap();
                        buf.clear();
                        assert_eq!(num, written + 1);

                        if num >= 100 {
                            break;
                        }
                        num += 1;
                    }

                    s1.close().await?;

                    Ok::<u32, io::Error>(num)
                })
                .unwrap();

            // Task 2: Use copy_buf_bidirectional to relay traffic.
            let h2 = rt
                .spawn_with_handle(copy_buf_bidirectional(s2, s3, eof::Close, eof::Close))
                .unwrap();

            // Task 3: Until EOF: read a number on a line, and write back 1 more.
            let h3 = rt
                .spawn_with_handle(async move {
                    let mut buf = String::new();
                    let mut last_written = None;

                    loop {
                        let n_bytes_read = s4.read_line(&mut buf).await?;
                        if n_bytes_read == 0 {
                            break;
                        }
                        let num: u32 = buf.trim_ascii().parse().unwrap();
                        buf.clear();
                        if let Some(last) = last_written {
                            assert_eq!(num, last + 1);
                        }

                        let num = num + 1;
                        s4.write_all(format!("{num}\n").as_bytes()).await?;
                        s4.flush().await?;
                        last_written = Some(num);
                    }
                    Ok::<_, io::Error>(())
                })
                .unwrap();

            let outcome1 = h1.await;
            let outcome2 = h2.await;
            let outcome3 = h3.await;

            assert_eq!(outcome1.unwrap(), 100);
            let (_, _) = outcome2.unwrap();
            let () = outcome3.unwrap();
        });
    }
}