rusty_sockslib/
copy_pump.rs1use std::time::Duration;
2
3use futures::{future::Either, pin_mut};
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use tokio::net::TcpStream;
6use tokio::time::timeout;
7
8use crate::helpers::{IntoError, Res};
9
10pub struct CopyPump {
11 client_socket: TcpStream,
12 endpoint_socket: TcpStream,
13 read_timeout: u64,
14}
15
16impl CopyPump {
17 pub fn from(client_socket: TcpStream, endpoint_socket: TcpStream, read_timeout: u64) -> Self {
18 CopyPump {
19 client_socket,
20 endpoint_socket,
21 read_timeout,
22 }
23 }
24
25 pub async fn start(self) -> Res<()> {
26 self.run_pumps_as_copy().await
27 }
28
29 async fn run_pumps_as_copy(self) -> Res<()> {
30 let (mut client_socket_read, mut client_socket_write) = self.client_socket.into_split();
31 let (mut endpoint_socket_read, mut endpoint_socket_write) = self.endpoint_socket.into_split();
32
33 let idle = match self.read_timeout {
38 0 => None,
39 ms => Some(Duration::from_millis(ms)),
40 };
41
42 let pump_up = Self::pump(&mut client_socket_read, &mut endpoint_socket_write, idle);
43 let pump_down = Self::pump(&mut endpoint_socket_read, &mut client_socket_write, idle);
44
45 pin_mut!(pump_up);
46 pin_mut!(pump_down);
47
48 match futures::future::select(pump_up, pump_down).await {
51 Either::Left((result, _)) | Either::Right((result, _)) => result,
52 }
53 }
54
55 async fn pump<R, W>(from: &mut R, to: &mut W, idle: Option<Duration>) -> Res<()>
59 where
60 R: AsyncRead + Unpin,
61 W: AsyncWrite + Unpin,
62 {
63 let mut buffer = [0u8; 16 * 1024];
64
65 loop {
66 let read = match idle {
67 Some(duration) => match timeout(duration, from.read(&mut buffer)).await {
68 Ok(result) => result?,
69 Err(_) => return "Idle timeout.".into_error(),
70 },
71 None => from.read(&mut buffer).await?,
72 };
73
74 if read == 0 {
76 return Ok(());
77 }
78
79 to.write_all(&buffer[..read]).await?;
80 to.flush().await?;
81 }
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::CopyPump;
    use std::time::Duration;
    use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
    use tokio::time::{sleep, timeout};

    /// A steady trickle of bytes, each arriving well inside the idle window, must keep the
    /// pump alive even though the whole transfer outlasts that window.
    #[tokio::test]
    async fn idle_timeout_resets_on_activity() {
        const CHUNKS: usize = 10;
        let window = Duration::from_millis(250);

        let (mut producer, mut pump_in) = duplex(256);
        let (mut pump_out, mut consumer) = duplex(256);

        let produce = async move {
            for _ in 0..CHUNKS {
                producer.write_all(b"x").await.unwrap();
                producer.flush().await.unwrap();
                sleep(Duration::from_millis(50)).await;
            }
            // Closing the producer is what lets the pump observe end-of-stream.
            drop(producer);
        };

        let consume = async move {
            let mut scratch = [0u8; 16];
            let mut seen = 0;
            while seen < CHUNKS {
                let n = match consumer.read(&mut scratch).await {
                    Ok(n) if n > 0 => n,
                    _ => break,
                };
                seen += n;
            }
            seen
        };

        let pumping = CopyPump::pump(&mut pump_in, &mut pump_out, Some(window));
        let everything = async { tokio::join!(pumping, produce, consume) };

        let (pump_result, (), received) = timeout(Duration::from_secs(5), everything)
            .await
            .expect("pump + driver should finish well within 5s");

        assert!(pump_result.is_ok(), "active connection was killed: {:?}", pump_result.err());
        assert_eq!(received, CHUNKS, "all bytes should have been pumped through");
    }

    /// With nothing arriving, the pump must bail out shortly after the idle window.
    #[tokio::test]
    async fn idle_timeout_fires_when_silent() {
        // The far ends are bound to `_`-prefixed names (not `_`) so they stay alive until the
        // end of the test; a dropped peer would presumably look like EOF rather than silence.
        let (_quiet_writer, mut pump_in) = duplex(64);
        let (mut pump_out, _unread) = duplex(64);
        let window = Duration::from_millis(100);

        let pumping = CopyPump::pump(&mut pump_in, &mut pump_out, Some(window));
        let outcome = timeout(Duration::from_secs(2), pumping)
            .await
            .expect("pump should give up around the idle window, well before 2s");

        assert!(outcome.is_err(), "silent connection should have hit the idle timeout");
    }

    /// With the idle timeout disabled, a silent pump must still be waiting when the outer
    /// guard timer expires.
    #[tokio::test]
    async fn disabled_idle_timeout_never_fires() {
        let (_quiet_writer, mut pump_in) = duplex(64);
        let (mut pump_out, _unread) = duplex(64);

        let pumping = CopyPump::pump(&mut pump_in, &mut pump_out, None);
        let still_running = timeout(Duration::from_millis(300), pumping)
            .await
            .is_err();

        assert!(still_running, "with idle disabled the pump must keep waiting, not return");
    }
}