Skip to main content

borer_core/
proxy.rs

1use anyhow::Context;
2use log::{debug, error};
3use socks5_proto::{
4    Address, Command, Reply, Request as SocksRequest, Response,
5    handshake::{self, Method},
6};
7use tokio::{
8    io::{AsyncWriteExt, copy_bidirectional},
9    net::TcpStream,
10};
11
12use crate::dial::Dial;
13use crate::proto::http_connect;
14
15pub struct ProxyConnection {
16    ts: TcpStream,
17    dial: Box<dyn Dial>,
18}
19
20impl ProxyConnection {
21    pub fn new(ts: TcpStream, dial: Box<dyn Dial>) -> Self {
22        Self { ts, dial }
23    }
24
25    /// Detect the inbound protocol and proxy the connection until either side closes.
26    pub async fn handle(self) {
27        let mut first_bit = [0u8];
28        if let Err(e) = self.ts.peek(&mut first_bit).await {
29            error!("can't peek first_bit err: {e}");
30            return;
31        }
32
33        let ret = if first_bit[0] == socks5_proto::SOCKS_VERSION {
34            self.handle_socks().await
35        } else {
36            self.handle_http().await
37        };
38        if let Err(e) = ret {
39            error!("proxy handle err: {e:?}");
40        };
41    }
42
43    async fn handle_socks(mut self) -> anyhow::Result<()> {
44        debug!(
45            "socks proxy connection {:?} to {:?}",
46            self.ts.peer_addr().ok(),
47            self.ts.local_addr().ok()
48        );
49
50        let _req = handshake::Request::read_from(&mut self.ts)
51            .await
52            .context("socks handshake failed")?;
53
54        let resp = handshake::Response::new(Method::NONE);
55
56        resp.write_to(&mut self.ts)
57            .await
58            .context("socks write response failed")?;
59
60        let req = SocksRequest::read_from(&mut self.ts)
61            .await
62            .context("socks read request failed")?;
63
64        let addr = req.address;
65
66        debug!("start connect {addr}");
67        match req.command {
68            Command::Connect => {
69                let target = self.dial.dial(addr.clone()).await;
70                match target {
71                    Ok(mut target) => {
72                        self.socks_reply(Reply::Succeeded, Address::unspecified())
73                            .await?;
74
75                        if let Ok((a, b)) = copy_bidirectional(&mut self.ts, &mut target).await {
76                            debug!(
77                                "socks copy end for {} traffic: {}<=>{} total: {}",
78                                addr,
79                                a,
80                                b,
81                                a + b
82                            );
83                        }
84
85                        Ok(())
86                    }
87                    Err(e) => {
88                        self.socks_reply(Reply::HostUnreachable, Address::unspecified())
89                            .await
90                            .context("socks reply failed.")?;
91                        Err(e).context(format!("socks dial {addr} failed ."))
92                    }
93                }
94            }
95            cmd => {
96                debug!("socks unsupported command {:?}", cmd);
97                self.socks_reply(Reply::CommandNotSupported, Address::unspecified())
98                    .await?;
99                Ok(())
100            }
101        }
102    }
103
104    async fn handle_http(mut self) -> anyhow::Result<()> {
105        debug!(
106            "http proxy connection {:?} to {:?}",
107            self.ts.peer_addr().ok(),
108            self.ts.local_addr().ok()
109        );
110        let buf = http_connect::read_http_request_end(&mut self.ts)
111            .await
112            .context("http proxy read http request end failed")?;
113
114        debug!(
115            "http proxy read buf: \n{}",
116            String::from_utf8_lossy(buf.as_slice())
117        );
118        match http_connect::HttpConnectRequest::parse(buf.as_slice()) {
119            Ok(req) => {
120                let addr = req.addr().clone();
121                let mut target = self
122                    .dial
123                    .dial(addr.clone())
124                    .await
125                    .context(format!("http proxy connect addr {} failed", addr))?;
126
127                if let Some(data) = req.nugget() {
128                    target
129                        .write_all(data.data().as_slice())
130                        .await
131                        .context("http proxy target write_all buf failed")?;
132                    target
133                        .flush()
134                        .await
135                        .context("http proxy flush target failed")?;
136                } else {
137                    self.ts
138                        .write("HTTP/1.1 200 OK\r\n\r\n".as_bytes())
139                        .await
140                        .context("http proxy write response failed")?;
141                }
142
143                if let Ok((a, b)) = copy_bidirectional(&mut self.ts, &mut target).await {
144                    debug!(
145                        "http copy end for {} traffic: {}<=>{} total: {}",
146                        addr,
147                        a,
148                        b,
149                        a + b
150                    );
151                };
152                Ok(())
153            }
154            Err(e) => {
155                debug!("http proxy BAD_REQUEST");
156                self.ts
157                    .write("HTTP/1.1 400 BAD_REQUEST\r\n\r\n".as_bytes())
158                    .await
159                    .context("http proxy write response failed")?;
160                Err(e).context("http dial failed .".to_string())
161            }
162        }
163    }
164
165    async fn socks_reply(&mut self, reply: Reply, addr: Address) -> anyhow::Result<()> {
166        let resp = Response::new(reply, addr);
167        resp.write_to(&mut self.ts)
168            .await
169            .context("scoks write reply response failed")
170    }
171}
172
#[cfg(test)]
mod tests {
    //! End-to-end tests for [`ProxyConnection`] over a real loopback TCP socket,
    //! with the outbound leg replaced by an in-memory duplex pipe.

    use std::sync::{Arc, Mutex};

    use anyhow::anyhow;
    use async_trait::async_trait;
    use socks5_proto::{
        Address, Command, Reply, Request as SocksRequest, Response,
        handshake::{Method, Request as HandshakeRequest, Response as HandshakeResponse},
    };
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt, DuplexStream},
        net::{TcpListener, TcpStream},
    };

    use super::ProxyConnection;
    use crate::dial::{AsyncStream, Dial};

    /// Test double for [`Dial`]: yields one pre-programmed result and records every
    /// address it was asked to dial.
    struct MockDial {
        /// The single result handed out; taken on the first call.
        result: Mutex<Option<anyhow::Result<DuplexStream>>>,
        /// Shared with the test so it can assert which addresses were dialed.
        seen_addrs: Arc<Mutex<Vec<Address>>>,
    }

    impl MockDial {
        /// A dialer whose only call succeeds with `stream`.
        fn succeed(stream: DuplexStream, seen_addrs: Arc<Mutex<Vec<Address>>>) -> Self {
            Self {
                result: Mutex::new(Some(Ok(stream))),
                seen_addrs,
            }
        }

        /// A dialer whose only call fails with `err`.
        fn fail(err: anyhow::Error, seen_addrs: Arc<Mutex<Vec<Address>>>) -> Self {
            Self {
                result: Mutex::new(Some(Err(err))),
                seen_addrs,
            }
        }
    }

    #[async_trait]
    impl Dial for MockDial {
        async fn dial(&self, addr: Address) -> anyhow::Result<Box<dyn AsyncStream>> {
            self.seen_addrs.lock().unwrap().push(addr);
            self.result
                .lock()
                .unwrap()
                .take()
                .expect("dial should only be called once")
                .map(|s| Box::new(s) as Box<dyn AsyncStream>)
        }
    }

    /// Return a connected loopback pair as `(server side, client side)`.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = TcpStream::connect(addr).await.unwrap();
        let (server, _) = listener.accept().await.unwrap();

        (server, client)
    }

    /// Run the client half of a no-auth SOCKS5 handshake and check that the proxy
    /// selected `Method::NONE`.
    async fn socks_no_auth_handshake(client: &mut TcpStream) {
        HandshakeRequest::new(vec![Method::NONE])
            .write_to(&mut *client)
            .await
            .unwrap();
        let handshake = HandshakeResponse::read_from(&mut *client).await.unwrap();
        assert_eq!(handshake.method, Method::NONE);
    }

    #[tokio::test]
    async fn handle_http_connect_replies_ok_and_dials_target() {
        let (server, mut client) = tcp_pair().await;
        let (target, mut target_peer) = tokio::io::duplex(256);
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::succeed(target, seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_http().await });
        let target_task = tokio::spawn(async move {
            let mut buf = Vec::new();
            target_peer.read_to_end(&mut buf).await.unwrap();
            buf
        });

        client
            .write_all(b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n")
            .await
            .unwrap();

        let mut response = [0u8; 19];
        client.read_exact(&mut response).await.unwrap();
        assert_eq!(&response, b"HTTP/1.1 200 OK\r\n\r\n");

        client.shutdown().await.unwrap();
        drop(client);

        proxy_task.await.unwrap().unwrap();
        assert!(target_task.await.unwrap().is_empty());
        assert_eq!(
            seen_addrs.lock().unwrap().as_slice(),
            &[Address::DomainAddress(b"example.com".to_vec(), 443)]
        );
    }

    #[tokio::test]
    async fn handle_http_with_nugget_forwards_request_body_to_target() {
        let (server, mut client) = tcp_pair().await;
        let (target, mut target_peer) = tokio::io::duplex(512);
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::succeed(target, seen_addrs.clone());
        let raw = b"GET https://upstream.example/path HTTP/1.1\r\nHost: service.internal\r\n\r\n";

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_http().await });
        let target_task = tokio::spawn(async move {
            let mut received = vec![0; raw.len()];
            target_peer.read_exact(&mut received).await.unwrap();
            target_peer
                .write_all(b"HTTP/1.1 204 No Content\r\n\r\n")
                .await
                .unwrap();
            target_peer.shutdown().await.unwrap();
            received
        });

        client.write_all(raw).await.unwrap();

        let mut response = vec![0; 27];
        client.read_exact(&mut response).await.unwrap();
        assert_eq!(response, b"HTTP/1.1 204 No Content\r\n\r\n");

        client.shutdown().await.unwrap();
        drop(client);

        proxy_task.await.unwrap().unwrap();
        assert_eq!(target_task.await.unwrap(), raw);
        assert_eq!(
            seen_addrs.lock().unwrap().as_slice(),
            &[Address::DomainAddress(b"service.internal".to_vec(), 443)]
        );
    }

    #[tokio::test]
    async fn handle_http_bad_request_returns_400_without_dialing() {
        let (server, mut client) = tcp_pair().await;
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::fail(anyhow!("dial should not be called"), seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_http().await });

        client.write_all(b"BAD\r\n\r\n").await.unwrap();

        let mut response = [0u8; 28];
        client.read_exact(&mut response).await.unwrap();
        assert_eq!(&response, b"HTTP/1.1 400 BAD_REQUEST\r\n\r\n");

        client.shutdown().await.unwrap();
        drop(client);

        assert!(proxy_task.await.unwrap().is_err());
        assert!(seen_addrs.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn handle_socks_connect_negotiates_and_replies_succeeded() {
        let (server, mut client) = tcp_pair().await;
        let (target, mut target_peer) = tokio::io::duplex(256);
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::succeed(target, seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_socks().await });
        let target_task = tokio::spawn(async move {
            let mut buf = Vec::new();
            target_peer.read_to_end(&mut buf).await.unwrap();
            buf
        });

        HandshakeRequest::new(vec![Method::NONE])
            .write_to(&mut client)
            .await
            .unwrap();
        let handshake = HandshakeResponse::read_from(&mut client).await.unwrap();
        assert_eq!(handshake.method, Method::NONE);

        SocksRequest::new(
            Command::Connect,
            Address::DomainAddress(b"example.com".to_vec(), 1080),
        )
        .write_to(&mut client)
        .await
        .unwrap();

        let response = Response::read_from(&mut client).await.unwrap();
        assert_eq!(response.reply, Reply::Succeeded);

        client.shutdown().await.unwrap();
        drop(client);

        proxy_task.await.unwrap().unwrap();
        assert!(target_task.await.unwrap().is_empty());
        assert_eq!(
            seen_addrs.lock().unwrap().as_slice(),
            &[Address::DomainAddress(b"example.com".to_vec(), 1080)]
        );
    }

    /// A failed dial is reported to the client as `HostUnreachable` and surfaced
    /// to the caller as an error.
    #[tokio::test]
    async fn handle_socks_dial_failure_replies_host_unreachable() {
        let (server, mut client) = tcp_pair().await;
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::fail(anyhow!("connection refused"), seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_socks().await });

        socks_no_auth_handshake(&mut client).await;
        SocksRequest::new(
            Command::Connect,
            Address::DomainAddress(b"unreachable.example".to_vec(), 80),
        )
        .write_to(&mut client)
        .await
        .unwrap();

        let response = Response::read_from(&mut client).await.unwrap();
        assert_eq!(response.reply, Reply::HostUnreachable);

        assert!(proxy_task.await.unwrap().is_err());
        assert_eq!(
            seen_addrs.lock().unwrap().as_slice(),
            &[Address::DomainAddress(b"unreachable.example".to_vec(), 80)]
        );
    }

    /// Commands other than CONNECT are rejected without ever dialing.
    #[tokio::test]
    async fn handle_socks_unsupported_command_replies_not_supported_without_dialing() {
        let (server, mut client) = tcp_pair().await;
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::fail(anyhow!("dial should not be called"), seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle_socks().await });

        socks_no_auth_handshake(&mut client).await;
        SocksRequest::new(
            Command::Bind,
            Address::DomainAddress(b"example.com".to_vec(), 1080),
        )
        .write_to(&mut client)
        .await
        .unwrap();

        let response = Response::read_from(&mut client).await.unwrap();
        assert_eq!(response.reply, Reply::CommandNotSupported);

        proxy_task.await.unwrap().unwrap();
        assert!(seen_addrs.lock().unwrap().is_empty());
    }

    /// `handle()` routes a first byte equal to the SOCKS version to the SOCKS path.
    #[tokio::test]
    async fn handle_dispatches_socks_version_byte_to_socks() {
        let (server, mut client) = tcp_pair().await;
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::fail(anyhow!("dial should not be called"), seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle().await });

        socks_no_auth_handshake(&mut client).await;
        SocksRequest::new(
            Command::Bind,
            Address::DomainAddress(b"example.com".to_vec(), 1080),
        )
        .write_to(&mut client)
        .await
        .unwrap();

        let response = Response::read_from(&mut client).await.unwrap();
        assert_eq!(response.reply, Reply::CommandNotSupported);

        proxy_task.await.unwrap();
        assert!(seen_addrs.lock().unwrap().is_empty());
    }

    /// `handle()` routes any other first byte to the HTTP path.
    #[tokio::test]
    async fn handle_dispatches_non_socks_first_byte_to_http() {
        let (server, mut client) = tcp_pair().await;
        let (target, mut target_peer) = tokio::io::duplex(256);
        let seen_addrs = Arc::new(Mutex::new(Vec::new()));
        let dial = MockDial::succeed(target, seen_addrs.clone());

        let proxy = ProxyConnection::new(server, Box::new(dial));
        let proxy_task = tokio::spawn(async move { proxy.handle().await });
        let target_task = tokio::spawn(async move {
            let mut buf = Vec::new();
            target_peer.read_to_end(&mut buf).await.unwrap();
            buf
        });

        client
            .write_all(b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n")
            .await
            .unwrap();

        let mut response = [0u8; 19];
        client.read_exact(&mut response).await.unwrap();
        assert_eq!(&response, b"HTTP/1.1 200 OK\r\n\r\n");

        client.shutdown().await.unwrap();
        drop(client);

        proxy_task.await.unwrap();
        assert!(target_task.await.unwrap().is_empty());
        assert_eq!(
            seen_addrs.lock().unwrap().as_slice(),
            &[Address::DomainAddress(b"example.com".to_vec(), 443)]
        );
    }
}