1use crate::proxy::{ProxyScheme, UpstreamProxy};
2use anyhow::{Context, Result, bail};
3use std::{net::IpAddr, str};
4use tokio::{
5 io::{AsyncReadExt, AsyncWriteExt},
6 net::{TcpListener, TcpStream},
7 sync::oneshot,
8 task::JoinHandle,
9};
10
/// Size limit, in bytes, for the data buffered while waiting for an HTTP request header.
///
/// `read_http_request_head` fails once its buffer has grown past this value without
/// containing the blank-line header terminator.
const MAX_HEADER_BYTES: usize = 64 * 1024;
12
/// Handle to a local proxy listener whose accept loop runs on a spawned task and relays
/// connections to an upstream proxy.
pub struct ProxyBridge {
    /// `http://` URL built from the listener's bound local address.
    local_url: String,
    /// Sender used to tell the accept loop to stop; taken (left as `None`) by `shutdown`.
    shutdown: Option<oneshot::Sender<()>>,
    /// Join handle of the spawned accept-loop task.
    task: JoinHandle<Result<()>>,
}
18
19impl ProxyBridge {
20 pub async fn start(upstream: UpstreamProxy, listen_port: Option<u16>) -> Result<Self> {
21 let listener = TcpListener::bind(("127.0.0.1", listen_port.unwrap_or(0)))
22 .await
23 .context("failed to bind local bridge listener")?;
24 let local_addr = listener.local_addr()?;
25 let local_url = format!("http://{local_addr}");
26 let (shutdown_tx, shutdown_rx) = oneshot::channel();
27 let task = tokio::spawn(run_server(listener, upstream, shutdown_rx));
28
29 Ok(Self {
30 local_url,
31 shutdown: Some(shutdown_tx),
32 task,
33 })
34 }
35
36 pub fn local_proxy_url(&self) -> String {
37 self.local_url.clone()
38 }
39
40 pub async fn shutdown(mut self) -> Result<()> {
41 if let Some(shutdown) = self.shutdown.take() {
42 let _ = shutdown.send(());
43 }
44
45 self.task
46 .await
47 .context("local proxy bridge task failed to join")?
48 }
49}
50
51async fn run_server(
52 listener: TcpListener,
53 upstream: UpstreamProxy,
54 mut shutdown: oneshot::Receiver<()>,
55) -> Result<()> {
56 loop {
57 tokio::select! {
58 result = listener.accept() => {
59 let (client, peer) = result.context("failed to accept local proxy connection")?;
60 let upstream = upstream.clone();
61 tokio::spawn(async move {
62 if let Err(error) = handle_client(client, upstream).await {
63 tracing::debug!("local proxy connection from {peer} failed: {error:#}");
64 }
65 });
66 }
67 _ = &mut shutdown => {
68 return Ok(());
69 }
70 }
71 }
72}
73
74async fn handle_client(mut client: TcpStream, upstream: UpstreamProxy) -> Result<()> {
75 let request_bytes = read_http_request_head(&mut client).await?;
76 let header_end = find_header_end(&request_bytes).context("HTTP header terminator not found")?;
77 let (head, leftover) = request_bytes.split_at(header_end);
78 let request = parse_http_request(head)?;
79
80 match upstream.scheme() {
81 ProxyScheme::Http => {
82 let mut upstream_stream = TcpStream::connect((upstream.host(), upstream.port()))
83 .await
84 .with_context(|| {
85 format!(
86 "failed to connect upstream HTTP proxy {}",
87 upstream.authority()
88 )
89 })?;
90 let outgoing =
91 add_proxy_authorization(head, upstream.basic_proxy_authorization().as_deref());
92 upstream_stream.write_all(&outgoing).await?;
93 if !leftover.is_empty() {
94 upstream_stream.write_all(leftover).await?;
95 }
96 tokio::io::copy_bidirectional(&mut client, &mut upstream_stream).await?;
97 }
98 ProxyScheme::Socks5 => {
99 if !request.method.eq_ignore_ascii_case("CONNECT") {
100 write_proxy_error(
101 &mut client,
102 501,
103 "Only CONNECT is supported for SOCKS upstreams",
104 )
105 .await?;
106 bail!("non-CONNECT request is not supported for SOCKS upstreams");
107 }
108
109 let (target_host, target_port) = parse_host_port(&request.target)?;
110 let mut upstream_stream =
111 connect_via_socks5(&upstream, &target_host, target_port).await?;
112 client
113 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
114 .await?;
115 if !leftover.is_empty() {
116 upstream_stream.write_all(leftover).await?;
117 }
118 tokio::io::copy_bidirectional(&mut client, &mut upstream_stream).await?;
119 }
120 }
121
122 Ok(())
123}
124
125async fn read_http_request_head(stream: &mut TcpStream) -> Result<Vec<u8>> {
126 let mut buffer = Vec::with_capacity(4096);
127 let mut chunk = [0_u8; 2048];
128
129 loop {
130 let read = stream.read(&mut chunk).await?;
131 if read == 0 {
132 bail!("connection closed before HTTP header was complete");
133 }
134
135 buffer.extend_from_slice(&chunk[..read]);
136 if find_header_end(&buffer).is_some() {
137 return Ok(buffer);
138 }
139 if buffer.len() > MAX_HEADER_BYTES {
140 bail!("HTTP proxy request header is too large");
141 }
142 }
143}
144
/// Method and target taken from the first line of an HTTP request head.
#[derive(Debug, Eq, PartialEq)]
struct HttpRequest {
    /// First whitespace-separated token of the request line (for example `CONNECT`).
    method: String,
    /// Second whitespace-separated token of the request line.
    target: String,
}
150
151fn parse_http_request(head: &[u8]) -> Result<HttpRequest> {
152 let text = str::from_utf8(head).context("HTTP request header is not valid UTF-8")?;
153 let first_line = text.lines().next().context("HTTP request is empty")?;
154 let mut parts = first_line.split_whitespace();
155 let method = parts.next().context("HTTP request is missing method")?;
156 let target = parts.next().context("HTTP request is missing target")?;
157 let version = parts.next().context("HTTP request is missing version")?;
158
159 if !version.starts_with("HTTP/") {
160 bail!("invalid HTTP proxy request version: {version}");
161 }
162
163 Ok(HttpRequest {
164 method: method.to_string(),
165 target: target.to_string(),
166 })
167}
168
/// Returns the offset just past the first `\r\n\r\n` in `buffer`, or `None` when the
/// sequence does not occur.
fn find_header_end(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let start = buffer
        .windows(TERMINATOR.len())
        .position(|candidate| candidate == TERMINATOR)?;
    Some(start + TERMINATOR.len())
}
175
/// Returns a copy of `head` with a `Proxy-Authorization: <authorization>` header line
/// inserted just before the final blank line.
///
/// `head` is returned unchanged when `authorization` is `None`, when it already contains a
/// `Proxy-Authorization` header line (matched case-insensitively), or when it has no
/// `\r\n\r\n` terminator.
///
/// All searching is done on the raw bytes, so the offsets used for slicing are always
/// valid for `head` even if it is not valid UTF-8.
fn add_proxy_authorization(head: &[u8], authorization: Option<&str>) -> Vec<u8> {
    const EXISTING_HEADER: &[u8] = b"\r\nproxy-authorization:";
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    const INJECTED_PREFIX: &[u8] = b"\r\nProxy-Authorization: ";

    let Some(authorization) = authorization else {
        return head.to_vec();
    };

    let already_present = head
        .windows(EXISTING_HEADER.len())
        .any(|window| window.eq_ignore_ascii_case(EXISTING_HEADER));
    if already_present {
        return head.to_vec();
    }

    // Insert before the last terminator so the new line ends up as the final header.
    let last_terminator = head
        .windows(TERMINATOR.len())
        .rposition(|window| window == TERMINATOR);
    let Some(insert_at) = last_terminator else {
        return head.to_vec();
    };

    let mut outgoing =
        Vec::with_capacity(head.len() + INJECTED_PREFIX.len() + authorization.len());
    outgoing.extend_from_slice(&head[..insert_at]);
    outgoing.extend_from_slice(INJECTED_PREFIX);
    outgoing.extend_from_slice(authorization.as_bytes());
    outgoing.extend_from_slice(&head[insert_at..]);
    outgoing
}
199
200fn parse_host_port(value: &str) -> Result<(String, u16)> {
201 if let Some(rest) = value.strip_prefix('[') {
202 let (host, tail) = rest
203 .split_once(']')
204 .context("invalid bracketed IPv6 CONNECT target")?;
205 let port = tail
206 .strip_prefix(':')
207 .context("IPv6 CONNECT target is missing port")?
208 .parse()
209 .context("invalid CONNECT target port")?;
210 return Ok((host.to_string(), port));
211 }
212
213 let (host, port) = value
214 .rsplit_once(':')
215 .context("CONNECT target must be host:port")?;
216 if host.is_empty() {
217 bail!("CONNECT target host cannot be empty");
218 }
219
220 Ok((
221 host.to_string(),
222 port.parse().context("invalid CONNECT target port")?,
223 ))
224}
225
226async fn connect_via_socks5(
227 proxy: &UpstreamProxy,
228 target_host: &str,
229 target_port: u16,
230) -> Result<TcpStream> {
231 let mut stream = TcpStream::connect((proxy.host(), proxy.port()))
232 .await
233 .with_context(|| {
234 format!(
235 "failed to connect upstream SOCKS5 proxy {}",
236 proxy.authority()
237 )
238 })?;
239
240 if proxy.has_auth() {
241 stream.write_all(&[0x05, 0x02, 0x00, 0x02]).await?;
242 } else {
243 stream.write_all(&[0x05, 0x01, 0x00]).await?;
244 }
245
246 let mut method_response = [0_u8; 2];
247 stream.read_exact(&mut method_response).await?;
248 if method_response[0] != 0x05 {
249 bail!("invalid SOCKS5 method response");
250 }
251
252 match method_response[1] {
253 0x00 => {}
254 0x02 => authenticate_socks5(proxy, &mut stream).await?,
255 0xff => bail!("SOCKS5 proxy rejected all authentication methods"),
256 method => bail!("SOCKS5 proxy selected unsupported authentication method {method:#x}"),
257 }
258
259 let request = build_socks5_connect_request(target_host, target_port)?;
260 stream.write_all(&request).await?;
261
262 let mut response = [0_u8; 4];
263 stream.read_exact(&mut response).await?;
264 if response[0] != 0x05 {
265 bail!("invalid SOCKS5 connect response");
266 }
267 if response[1] != 0x00 {
268 bail!("SOCKS5 connect failed with code {:#x}", response[1]);
269 }
270
271 read_socks5_bound_address(&mut stream, response[3]).await?;
272 Ok(stream)
273}
274
275async fn authenticate_socks5(proxy: &UpstreamProxy, stream: &mut TcpStream) -> Result<()> {
276 let username = proxy.username().unwrap_or_default().as_bytes();
277 let password = proxy.password().unwrap_or_default().as_bytes();
278 if username.len() > u8::MAX as usize || password.len() > u8::MAX as usize {
279 bail!("SOCKS5 username and password must be at most 255 bytes");
280 }
281
282 let mut request = Vec::with_capacity(username.len() + password.len() + 3);
283 request.push(0x01);
284 request.push(username.len() as u8);
285 request.extend_from_slice(username);
286 request.push(password.len() as u8);
287 request.extend_from_slice(password);
288 stream.write_all(&request).await?;
289
290 let mut response = [0_u8; 2];
291 stream.read_exact(&mut response).await?;
292 if response != [0x01, 0x00] {
293 bail!("SOCKS5 username/password authentication failed");
294 }
295 Ok(())
296}
297
298fn build_socks5_connect_request(target_host: &str, target_port: u16) -> Result<Vec<u8>> {
299 let mut request = vec![0x05, 0x01, 0x00];
300
301 match target_host.parse::<IpAddr>() {
302 Ok(IpAddr::V4(address)) => {
303 request.push(0x01);
304 request.extend_from_slice(&address.octets());
305 }
306 Ok(IpAddr::V6(address)) => {
307 request.push(0x04);
308 request.extend_from_slice(&address.octets());
309 }
310 Err(_) => {
311 let host = target_host.as_bytes();
312 if host.len() > u8::MAX as usize {
313 bail!("SOCKS5 target host is too long");
314 }
315 request.push(0x03);
316 request.push(host.len() as u8);
317 request.extend_from_slice(host);
318 }
319 }
320
321 request.extend_from_slice(&target_port.to_be_bytes());
322 Ok(request)
323}
324
325async fn read_socks5_bound_address(stream: &mut TcpStream, address_type: u8) -> Result<()> {
326 match address_type {
327 0x01 => {
328 let mut buffer = [0_u8; 4 + 2];
329 stream.read_exact(&mut buffer).await?;
330 }
331 0x03 => {
332 let mut length = [0_u8; 1];
333 stream.read_exact(&mut length).await?;
334 let mut buffer = vec![0_u8; length[0] as usize + 2];
335 stream.read_exact(&mut buffer).await?;
336 }
337 0x04 => {
338 let mut buffer = [0_u8; 16 + 2];
339 stream.read_exact(&mut buffer).await?;
340 }
341 other => bail!("invalid SOCKS5 address type {other:#x}"),
342 }
343 Ok(())
344}
345
346async fn write_proxy_error(stream: &mut TcpStream, code: u16, message: &str) -> Result<()> {
347 let response = format!(
348 "HTTP/1.1 {code} {message}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{message}",
349 message.len()
350 );
351 stream.write_all(response.as_bytes()).await?;
352 Ok(())
353}
354
/// Unit tests for the pure parsing and encoding helpers of this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Both the plain `host:port` form and the bracketed form are split correctly,
    /// with the brackets removed from the returned host.
    #[test]
    fn parses_connect_targets() {
        assert_eq!(
            parse_host_port("discord.com:443").unwrap(),
            ("discord.com".to_string(), 443)
        );
        assert_eq!(
            parse_host_port("[::1]:443").unwrap(),
            ("::1".to_string(), 443)
        );
    }

    /// The header is added as its own line and the head still ends with a blank line.
    #[test]
    fn injects_proxy_authorization_header() {
        let head = b"CONNECT discord.com:443 HTTP/1.1\r\nHost: discord.com:443\r\n\r\n";

        let outgoing = add_proxy_authorization(head, Some("Basic abc"));
        let text = String::from_utf8(outgoing).unwrap();

        assert!(text.contains("\r\nProxy-Authorization: Basic abc\r\n"));
        assert!(text.ends_with("\r\n\r\n"));
    }

    /// A head that already carries the header is passed through byte-for-byte.
    #[test]
    fn does_not_duplicate_proxy_authorization_header() {
        let head = b"CONNECT discord.com:443 HTTP/1.1\r\nProxy-Authorization: Basic old\r\n\r\n";

        let outgoing = add_proxy_authorization(head, Some("Basic new"));

        assert_eq!(outgoing, head);
    }

    /// A non-IP host is encoded as type `0x03` with a length prefix, followed by the
    /// big-endian port.
    #[test]
    fn builds_domain_socks_connect_request() {
        let request = build_socks5_connect_request("discord.com", 443).unwrap();

        assert_eq!(&request[..5], &[0x05, 0x01, 0x00, 0x03, 11]);
        assert_eq!(&request[5..16], b"discord.com");
        assert_eq!(&request[16..], &443_u16.to_be_bytes());
    }
}