1use std::{
5 io,
6 pin::Pin,
7 task::{Context, Poll, ready},
8};
9
10use futures::{AsyncBufRead, AsyncWrite};
11use pin_project::pin_project;
12
13use crate::{
14 arc_io_result::{ArcIoResult, wrap_error},
15 copy_buf::poll_copy_r_to_w,
16 eof::EofStrategy,
17 fuse_buf_reader::FuseBufReader,
18};
19
20pub fn copy_buf_bidirectional<A, B, AE, BE>(
44 stream_a: A,
45 stream_b: B,
46 on_a_eof: AE,
47 on_b_eof: BE,
48) -> CopyBufBidirectional<A, B, AE, BE>
49where
50 A: AsyncBufRead + AsyncWrite,
51 B: AsyncBufRead + AsyncWrite,
52 AE: EofStrategy<B>,
53 BE: EofStrategy<A>,
54{
55 CopyBufBidirectional {
56 stream_a: FuseBufReader::new(stream_a),
57 stream_b: FuseBufReader::new(stream_b),
58 on_a_eof,
59 on_b_eof,
60 copied_a_to_b: 0,
61 copied_b_to_a: 0,
62 a_to_b_status: DirectionStatus::Copying,
63 b_to_a_status: DirectionStatus::Copying,
64 }
65}
66
/// Future returned by [`copy_buf_bidirectional`].
///
/// Copies data from `stream_a` to `stream_b` and from `stream_b` to `stream_a`. After each
/// direction has finished copying, it uses that direction's EOF strategy to send EOF to
/// the destination stream.
///
/// Resolves to `(bytes copied a -> b, bytes copied b -> a)` once both directions are
/// done, or to an error as soon as either direction fails.
#[derive(Debug)]
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct CopyBufBidirectional<A, B, AE, BE> {
    /// The first stream: source for the a -> b direction, destination for b -> a.
    #[pin]
    stream_a: FuseBufReader<A>,

    /// The second stream: source for the b -> a direction, destination for a -> b.
    #[pin]
    stream_b: FuseBufReader<B>,

    /// Strategy used to send EOF to `stream_b` once `stream_a` has been fully copied.
    #[pin]
    on_a_eof: AE,

    /// Strategy used to send EOF to `stream_a` once `stream_b` has been fully copied.
    #[pin]
    on_b_eof: BE,

    /// Number of bytes copied so far from `stream_a` to `stream_b`.
    copied_a_to_b: u64,
    /// Number of bytes copied so far from `stream_b` to `stream_a`.
    copied_b_to_a: u64,

    /// Progress of the `stream_a` -> `stream_b` direction.
    a_to_b_status: DirectionStatus,

    /// Progress of the `stream_b` -> `stream_a` direction.
    b_to_a_status: DirectionStatus,
}
105
106impl<A, B, AE, BE> CopyBufBidirectional<A, B, AE, BE> {
107 pub fn into_inner(self) -> (A, B) {
109 (self.stream_a.into_inner(), self.stream_b.into_inner())
110 }
111}
112
113impl<A, B, AE, BE> Future for CopyBufBidirectional<A, B, AE, BE>
114where
115 A: AsyncBufRead + AsyncWrite,
116 B: AsyncBufRead + AsyncWrite,
117 AE: EofStrategy<B>,
118 BE: EofStrategy<A>,
119{
120 type Output = io::Result<(u64, u64)>;
121
122 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123 use DirectionStatus::*;
124
125 let mut this = self.project();
126
127 if *this.a_to_b_status != DirectionStatus::Done {
128 let _ignore_completion = one_direction(
129 cx,
130 this.stream_a.as_mut(),
131 this.stream_b.as_mut(),
132 this.on_a_eof,
133 this.copied_a_to_b,
134 this.a_to_b_status,
135 )
136 .map_err(|e| wrap_error(&e))?;
137 }
138
139 if *this.b_to_a_status != DirectionStatus::Done {
140 let _ignore_completion = one_direction(
141 cx,
142 this.stream_b.as_mut(),
143 this.stream_a.as_mut(),
144 this.on_b_eof,
145 this.copied_b_to_a,
146 this.b_to_a_status,
147 )
148 .map_err(|e| wrap_error(&e))?;
149 }
150
151 if (*this.a_to_b_status, *this.b_to_a_status) == (Done, Done) {
152 Poll::Ready(Ok((*this.copied_a_to_b, *this.copied_b_to_a)))
153 } else {
154 Poll::Pending
155 }
156 }
157}
158
/// How far a single copy direction (reader -> writer) has progressed.
///
/// A direction moves strictly forward through the states
/// `Copying` -> `SendingEof` -> `Done`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum DirectionStatus {
    /// Still copying data from the reader to the writer.
    Copying,

    /// Copying has finished; EOF is being sent to the writer via the EOF strategy.
    SendingEof,

    /// Both copying and EOF delivery have completed; nothing remains to be done.
    Done,
}
171
172fn one_direction<A, B, AE>(
174 cx: &mut Context<'_>,
175 r: Pin<&mut FuseBufReader<A>>,
176 mut w: Pin<&mut FuseBufReader<B>>,
177 eof_strategy: Pin<&mut AE>,
178 n_copied: &mut u64,
179 status: &mut DirectionStatus,
180) -> Poll<ArcIoResult<()>>
181where
182 A: AsyncBufRead,
183 B: AsyncWrite,
184 AE: EofStrategy<B>,
185{
186 use DirectionStatus::*;
187
188 if *status == Copying {
189 let () = ready!(poll_copy_r_to_w(cx, r, w.as_mut(), n_copied, false))?;
190 *status = SendingEof;
191 }
192
193 if *status == SendingEof {
194 let () = ready!(eof_strategy.poll_send_eof(cx, w.get_pin_mut()))?;
195 *status = Done;
196 }
197
198 assert_eq!(*status, Done);
199 Poll::Ready(Ok(()))
200}
201
#[cfg(test)]
mod test {
    // Lints that are acceptable in test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::{eof, test::RWPair};

    use futures::{
        AsyncBufReadExt,
        io::{BufReader, BufWriter, Cursor},
    };
    use tor_rtcompat::SpawnExt as _;
    use tor_rtmock::{MockRuntime, io::stream_pair};

    /// Build a buffered in-memory stream whose read side yields `init_data`.
    ///
    /// Whatever is written to the stream ends up in the second `Cursor` of the `RWPair`.
    /// The tests retrieve it with `.into_inner().1.into_inner()`. (The assertions below
    /// rely on `RWPair` reading from its first element and writing to its second.)
    // The nested return type is only used in tests; naming it via an alias adds nothing.
    #[allow(clippy::type_complexity)]
    fn cursor_stream(init_data: &[u8]) -> BufReader<RWPair<Cursor<Vec<u8>>, Cursor<Vec<u8>>>> {
        BufReader::new(RWPair(
            Cursor::new(init_data.to_vec()),
            Cursor::new(Vec::new()),
        ))
    }

    /// Copy directly between two in-memory streams.
    ///
    /// Checks both returned byte counts, and that each stream received exactly what the
    /// other one provided.
    async fn test_transfer_cursor(data_1: &[u8], data_2: &[u8]) {
        let mut s1 = cursor_stream(data_1);
        let mut s2 = cursor_stream(data_2);

        let (t1, t2) = copy_buf_bidirectional(&mut s1, &mut s2, eof::Close, eof::Close)
            .await
            .unwrap();
        assert_eq!(t1, data_1.len() as u64);
        assert_eq!(t2, data_2.len() as u64);
        let out1 = s1.into_inner().1.into_inner();
        let out2 = s2.into_inner().1.into_inner();
        assert_eq!(&out1[..], data_2);
        assert_eq!(&out2[..], data_1);
    }

    /// Copy through two chained copiers running as separate tasks.
    ///
    /// Topology: `s1 <-> [h1 copier] <-> s2 == s3 <-> [h2 copier] <-> s4`, where `s2`/`s3`
    /// come from `stream_pair()`. `data_1` (read from `s1`) must arrive in `s4`, and
    /// `data_2` (read from `s4`) must arrive in `s1`.
    async fn test_transfer_streams(rt: &MockRuntime, data_1: &[u8], data_2: &[u8]) {
        let mut s1 = cursor_stream(data_1);
        let (s2, s3) = stream_pair();
        let mut s4 = cursor_stream(data_2);

        let h1 = rt
            .spawn_with_handle(async move {
                let r = copy_buf_bidirectional(&mut s1, BufReader::new(s2), eof::Close, eof::Close)
                    .await;
                (r, s1.into_inner().1.into_inner())
            })
            .unwrap();
        let h2 = rt
            .spawn_with_handle(async move {
                let r = copy_buf_bidirectional(BufReader::new(s3), &mut s4, eof::Close, eof::Close)
                    .await;
                (r, s4.into_inner().1.into_inner())
            })
            .unwrap();
        let (r1, buf1) = h1.await;
        let (r2, buf2) = h2.await;

        // In both copiers, "a -> b" carries data_1 and "b -> a" carries data_2.
        assert_eq!(r1.unwrap(), (data_1.len() as u64, data_2.len() as u64));
        assert_eq!(r2.unwrap(), (data_1.len() as u64, data_2.len() as u64));
        assert_eq!(&buf1, data_2);
        assert_eq!(&buf2, data_1);
    }

    /// Run both transfer scenarios under the mock runtime configurations provided by
    /// `MockRuntime::test_with_various`.
    fn test_transfer(data_1: &[u8], data_2: &[u8]) {
        MockRuntime::test_with_various(async |rt| {
            test_transfer_cursor(data_1, data_2).await;
            test_transfer_streams(&rt, data_1, data_2).await;
        });
    }

    /// A 1,234,567-byte buffer repeating `1..=x`; different `x` gives distinct payloads.
    fn big(x: u8) -> Vec<u8> {
        (1..=x).cycle().take(1_234_567).collect()
    }

    #[test]
    fn transfer_empty() {
        test_transfer(&[], &[]);
    }

    #[test]
    fn transfer_empty_small() {
        test_transfer(&[], b"hello world");
    }

    #[test]
    fn transfer_small() {
        test_transfer(b"hola mundo", b"hello world");
    }

    #[test]
    fn transfer_huge() {
        let big1 = big(79);
        let big2 = big(81);
        test_transfer(&big1, &big2);
    }

    /// A line-based request/response protocol relayed through `copy_buf_bidirectional`.
    ///
    /// Topology: `h1 <-> s1 == s2 <-> [h2 copier] <-> s3 == s4 <-> h3`.
    /// - `h1` sends a number, expects `number + 1` back, and repeats until it receives
    ///   100. It then closes its stream.
    /// - `h3` replies to every number it reads with `number + 1` until it sees EOF.
    ///
    /// Each request must be relayed before its reply can be produced, so this test would
    /// hang if the copier held data back.
    #[test]
    fn interactive_protocol() {
        use futures::io::AsyncWriteExt as _;
        MockRuntime::test_with_various(async |rt| {
            let (s1, s2) = stream_pair();
            let (s3, s4) = stream_pair();

            let mut s1 = BufReader::new(s1);
            // The copier writes through BufWriters whose capacity far exceeds any single
            // line. Presumably this makes the test hang unless the copier flushes its
            // output rather than leaving short lines sitting in the buffer.
            let s2 = BufReader::new(BufWriter::with_capacity(1024, s2));
            let s3 = BufReader::new(BufWriter::with_capacity(1024, s3));
            let mut s4 = BufReader::new(s4);

            let h1 = rt
                // Client: send `num`, read the reply, check it is `num + 1`, and continue
                // from one past the reply. It sends 1, 3, 5, ..., 99 and receives
                // 2, 4, ..., 100, stopping at 100 (or earlier on EOF).
                .spawn_with_handle(async move {
                    let mut buf = String::new();
                    let mut num: u32 = 1;

                    loop {
                        s1.write_all(format!("{num}\n").as_bytes()).await?;
                        s1.flush().await?;

                        let written = num;

                        let n_bytes_read = s1.read_line(&mut buf).await?;
                        if n_bytes_read == 0 {
                            break;
                        }
                        num = buf.trim_ascii().parse().unwrap();
                        buf.clear();
                        assert_eq!(num, written + 1);

                        if num >= 100 {
                            break;
                        }
                        num += 1;
                    }

                    // Closing our end should make the copier send EOF towards `h3`
                    // (using `eof::Close`).
                    s1.close().await?;

                    Ok::<u32, io::Error>(num)
                })
                .unwrap();

            let h2 = rt
                // The relay under test.
                .spawn_with_handle(copy_buf_bidirectional(s2, s3, eof::Close, eof::Close))
                .unwrap();

            let h3 = rt
                // Server: answer each number with `number + 1` until EOF. Returning drops
                // `s4`, presumably signalling EOF in the other direction so that `h2`
                // can finish.
                .spawn_with_handle(async move {
                    let mut buf = String::new();
                    let mut last_written = None;

                    loop {
                        let n_bytes_read = s4.read_line(&mut buf).await?;
                        if n_bytes_read == 0 {
                            break;
                        }
                        let num: u32 = buf.trim_ascii().parse().unwrap();
                        buf.clear();
                        if let Some(last) = last_written {
                            assert_eq!(num, last + 1);
                        }

                        let num = num + 1;
                        s4.write_all(format!("{num}\n").as_bytes()).await?;
                        s4.flush().await?;
                        last_written = Some(num);
                    }
                    Ok::<_, io::Error>(())
                })
                .unwrap();

            let outcome1 = h1.await;
            let outcome2 = h2.await;
            let outcome3 = h3.await;

            assert_eq!(outcome1.unwrap(), 100);
            let (_, _) = outcome2.unwrap();
            let () = outcome3.unwrap();
        });
    }
}