1use anyhow::Context;
2use log::{debug, error};
3use socks5_proto::{
4 Address, Command, Reply, Request as SocksRequest, Response,
5 handshake::{self, Method},
6};
7use tokio::{
8 io::{AsyncWriteExt, copy_bidirectional},
9 net::TcpStream,
10};
11
12use crate::dial::Dial;
13use crate::proto::http_connect;
14
15pub struct ProxyConnection {
16 ts: TcpStream,
17 dial: Box<dyn Dial>,
18}
19
20impl ProxyConnection {
21 pub fn new(ts: TcpStream, dial: Box<dyn Dial>) -> Self {
22 Self { ts, dial }
23 }
24
25 pub async fn handle(self) {
27 let mut first_bit = [0u8];
28 if let Err(e) = self.ts.peek(&mut first_bit).await {
29 error!("can't peek first_bit err: {e}");
30 return;
31 }
32
33 let ret = if first_bit[0] == socks5_proto::SOCKS_VERSION {
34 self.handle_socks().await
35 } else {
36 self.handle_http().await
37 };
38 if let Err(e) = ret {
39 error!("proxy handle err: {e:?}");
40 };
41 }
42
43 async fn handle_socks(mut self) -> anyhow::Result<()> {
44 debug!(
45 "socks proxy connection {:?} to {:?}",
46 self.ts.peer_addr().ok(),
47 self.ts.local_addr().ok()
48 );
49
50 let _req = handshake::Request::read_from(&mut self.ts)
51 .await
52 .context("socks handshake failed")?;
53
54 let resp = handshake::Response::new(Method::NONE);
55
56 resp.write_to(&mut self.ts)
57 .await
58 .context("socks write response failed")?;
59
60 let req = SocksRequest::read_from(&mut self.ts)
61 .await
62 .context("socks read request failed")?;
63
64 let addr = req.address;
65
66 debug!("start connect {addr}");
67 match req.command {
68 Command::Connect => {
69 let target = self.dial.dial(addr.clone()).await;
70 match target {
71 Ok(mut target) => {
72 self.socks_reply(Reply::Succeeded, Address::unspecified())
73 .await?;
74
75 if let Ok((a, b)) = copy_bidirectional(&mut self.ts, &mut target).await {
76 debug!(
77 "socks copy end for {} traffic: {}<=>{} total: {}",
78 addr,
79 a,
80 b,
81 a + b
82 );
83 }
84
85 Ok(())
86 }
87 Err(e) => {
88 self.socks_reply(Reply::HostUnreachable, Address::unspecified())
89 .await
90 .context("socks reply failed.")?;
91 Err(e).context(format!("socks dial {addr} failed ."))
92 }
93 }
94 }
95 cmd => {
96 debug!("socks unsupported command {:?}", cmd);
97 self.socks_reply(Reply::CommandNotSupported, Address::unspecified())
98 .await?;
99 Ok(())
100 }
101 }
102 }
103
104 async fn handle_http(mut self) -> anyhow::Result<()> {
105 debug!(
106 "http proxy connection {:?} to {:?}",
107 self.ts.peer_addr().ok(),
108 self.ts.local_addr().ok()
109 );
110 let buf = http_connect::read_http_request_end(&mut self.ts)
111 .await
112 .context("http proxy read http request end failed")?;
113
114 debug!(
115 "http proxy read buf: \n{}",
116 String::from_utf8_lossy(buf.as_slice())
117 );
118 match http_connect::HttpConnectRequest::parse(buf.as_slice()) {
119 Ok(req) => {
120 let addr = req.addr().clone();
121 let mut target = self
122 .dial
123 .dial(addr.clone())
124 .await
125 .context(format!("http proxy connect addr {} failed", addr))?;
126
127 if let Some(data) = req.nugget() {
128 target
129 .write_all(data.data().as_slice())
130 .await
131 .context("http proxy target write_all buf failed")?;
132 target
133 .flush()
134 .await
135 .context("http proxy flush target failed")?;
136 } else {
137 self.ts
138 .write("HTTP/1.1 200 OK\r\n\r\n".as_bytes())
139 .await
140 .context("http proxy write response failed")?;
141 }
142
143 if let Ok((a, b)) = copy_bidirectional(&mut self.ts, &mut target).await {
144 debug!(
145 "http copy end for {} traffic: {}<=>{} total: {}",
146 addr,
147 a,
148 b,
149 a + b
150 );
151 };
152 Ok(())
153 }
154 Err(e) => {
155 debug!("http proxy BAD_REQUEST");
156 self.ts
157 .write("HTTP/1.1 400 BAD_REQUEST\r\n\r\n".as_bytes())
158 .await
159 .context("http proxy write response failed")?;
160 Err(e).context("http dial failed .".to_string())
161 }
162 }
163 }
164
165 async fn socks_reply(&mut self, reply: Reply, addr: Address) -> anyhow::Result<()> {
166 let resp = Response::new(reply, addr);
167 resp.write_to(&mut self.ts)
168 .await
169 .context("scoks write reply response failed")
170 }
171}
172
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use anyhow::anyhow;
    use async_trait::async_trait;
    use socks5_proto::{
        Address, Command, Reply, Request as SocksRequest, Response,
        handshake::{Method, Request as HandshakeRequest, Response as HandshakeResponse},
    };
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt, DuplexStream},
        net::{TcpListener, TcpStream},
    };

    use super::ProxyConnection;
    use crate::dial::{AsyncStream, Dial};

    /// `Dial` double that hands out one pre-configured result and records every
    /// address it is asked to dial.
    struct MockDial {
        result: Mutex<Option<anyhow::Result<DuplexStream>>>,
        seen_addrs: Arc<Mutex<Vec<Address>>>,
    }

    impl MockDial {
        /// A dialer whose single dial succeeds with `stream`.
        fn succeed(stream: DuplexStream, seen_addrs: Arc<Mutex<Vec<Address>>>) -> Self {
            Self {
                result: Mutex::new(Some(Ok(stream))),
                seen_addrs,
            }
        }

        /// A dialer whose single dial fails with `err`.
        fn fail(err: anyhow::Error, seen_addrs: Arc<Mutex<Vec<Address>>>) -> Self {
            Self {
                result: Mutex::new(Some(Err(err))),
                seen_addrs,
            }
        }
    }

    #[async_trait]
    impl Dial for MockDial {
        async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
            self.seen_addrs.lock().unwrap().push(addr);
            self.result
                .lock()
                .unwrap()
                .take()
                .expect("dial should only be called once")
                .map(|s| Box::new(s) as Box<dyn AsyncStream>)
        }
    }

    /// Returns a connected `(server, client)` pair of loopback TCP sockets.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = TcpStream::connect(addr).await.unwrap();
        let (server, _) = listener.accept().await.unwrap();

        (server, client)
    }

    /// Performs the client side of SOCKS5 method negotiation, offering only
    /// "no authentication", and asserts that the proxy selects it.
    async fn socks_negotiate(client: &mut TcpStream) {
        HandshakeRequest::new(vec![Method::NONE])
            .write_to(client)
            .await
            .unwrap();
        let handshake = HandshakeResponse::read_from(client).await.unwrap();
        assert_eq!(handshake.method, Method::NONE);
    }

    #[tokio::test]
    async fn handle_http_connect_replies_ok_and_dials_target() {
        let (server, mut client) = tcp_pair().await;
        let (target, mut target_peer) = tokio::io::duplex(256);
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::succeed(target, seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_http().await });
        let target_task = tokio::spawn(async move {
            let mut buf = Vec::new();
            target_peer.read_to_end(&mut buf).await.unwrap();
            buf
        });

        client
            .write_all(b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n")
            .await
            .unwrap();

        let mut response = [0u8; 19];
        client.read_exact(&mut response).await.unwrap();
        assert_eq!(&response, b"HTTP/1.1 200 OK\r\n\r\n");

        client.shutdown().await.unwrap();
        drop(client);

        proxy_task.await.unwrap().unwrap();
        assert!(target_task.await.unwrap().is_empty());
        assert_eq!(
            seen_addrs.lock().unwrap().as_slice(),
            &[Address::DomainAddress(b"example.com".to_vec(), 443)]
        );
    }

    #[tokio::test]
    async fn handle_http_with_nugget_forwards_request_body_to_target() {
        let (server, mut client) = tcp_pair().await;
        let (target, mut target_peer) = tokio::io::duplex(512);
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::succeed(target, seen_addrs.clone());
        let raw = b"GET https://upstream.example/path HTTP/1.1\r\nHost: service.internal\r\n\r\n";

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_http().await });
        let target_task = tokio::spawn(async move {
            let mut received = vec![0; raw.len()];
            target_peer.read_exact(&mut received).await.unwrap();
            target_peer
                .write_all(b"HTTP/1.1 204 No Content\r\n\r\n")
                .await
                .unwrap();
            target_peer.shutdown().await.unwrap();
            received
        });

        client.write_all(raw).await.unwrap();

        let mut response = vec![0; 27];
        client.read_exact(&mut response).await.unwrap();
        assert_eq!(response, b"HTTP/1.1 204 No Content\r\n\r\n");

        client.shutdown().await.unwrap();
        drop(client);

        proxy_task.await.unwrap().unwrap();
        assert_eq!(target_task.await.unwrap(), raw);
        assert_eq!(
            seen_addrs.lock().unwrap().as_slice(),
            &[Address::DomainAddress(b"service.internal".to_vec(), 443)]
        );
    }

    #[tokio::test]
    async fn handle_http_bad_request_returns_400_without_dialing() {
        let (server, mut client) = tcp_pair().await;
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::fail(anyhow!("dial should not be called"), seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_http().await });

        client.write_all(b"BAD\r\n\r\n").await.unwrap();

        let mut response = [0u8; 28];
        client.read_exact(&mut response).await.unwrap();
        assert_eq!(&response, b"HTTP/1.1 400 BAD_REQUEST\r\n\r\n");

        client.shutdown().await.unwrap();
        drop(client);

        assert!(proxy_task.await.unwrap().is_err());
        assert!(seen_addrs.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn handle_socks_connect_negotiates_and_replies_succeeded() {
        let (server, mut client) = tcp_pair().await;
        let (target, mut target_peer) = tokio::io::duplex(256);
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::succeed(target, seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_socks().await });
        let target_task = tokio::spawn(async move {
            let mut buf = Vec::new();
            target_peer.read_to_end(&mut buf).await.unwrap();
            buf
        });

        socks_negotiate(&mut client).await;

        SocksRequest::new(
            Command::Connect,
            Address::DomainAddress(b"example.com".to_vec(), 1080),
        )
        .write_to(&mut client)
        .await
        .unwrap();

        let response = Response::read_from(&mut client).await.unwrap();
        assert_eq!(response.reply, Reply::Succeeded);

        client.shutdown().await.unwrap();
        drop(client);

        proxy_task.await.unwrap().unwrap();
        assert!(target_task.await.unwrap().is_empty());
        assert_eq!(
            seen_addrs.lock().unwrap().as_slice(),
            &[Address::DomainAddress(b"example.com".to_vec(), 1080)]
        );
    }

    #[tokio::test]
    async fn handle_socks_dial_failure_replies_host_unreachable() {
        let (server, mut client) = tcp_pair().await;
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::fail(anyhow!("connection refused"), seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_socks().await });

        socks_negotiate(&mut client).await;

        SocksRequest::new(
            Command::Connect,
            Address::DomainAddress(b"unreachable.example".to_vec(), 443),
        )
        .write_to(&mut client)
        .await
        .unwrap();

        let response = Response::read_from(&mut client).await.unwrap();
        assert_eq!(response.reply, Reply::HostUnreachable);

        assert!(proxy_task.await.unwrap().is_err());
        assert_eq!(
            seen_addrs.lock().unwrap().as_slice(),
            &[Address::DomainAddress(b"unreachable.example".to_vec(), 443)]
        );
    }

    #[tokio::test]
    async fn handle_socks_unsupported_command_replies_without_dialing() {
        let (server, mut client) = tcp_pair().await;
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::fail(anyhow!("dial should not be called"), seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_socks().await });

        socks_negotiate(&mut client).await;

        SocksRequest::new(
            Command::Bind,
            Address::DomainAddress(b"example.com".to_vec(), 1080),
        )
        .write_to(&mut client)
        .await
        .unwrap();

        let response = Response::read_from(&mut client).await.unwrap();
        assert_eq!(response.reply, Reply::CommandNotSupported);

        proxy_task.await.unwrap().unwrap();
        assert!(seen_addrs.lock().unwrap().is_empty());
    }
}