Skip to main content

discord_proxy/
bridge.rs

1use crate::proxy::{ProxyScheme, UpstreamProxy};
2use anyhow::{Context, Result, bail};
3use std::{net::IpAddr, str};
4use tokio::{
5    io::{AsyncReadExt, AsyncWriteExt},
6    net::{TcpListener, TcpStream},
7    sync::oneshot,
8    task::JoinHandle,
9};
10
/// Limit, in bytes, on how much of a client's request head is buffered while
/// waiting for the `\r\n\r\n` terminator before the request is rejected (64 KiB).
const MAX_HEADER_BYTES: usize = 65_536;
12
13pub struct ProxyBridge {
14    local_url: String,
15    shutdown: Option<oneshot::Sender<()>>,
16    task: JoinHandle<Result<()>>,
17}
18
19impl ProxyBridge {
20    pub async fn start(upstream: UpstreamProxy, listen_port: Option<u16>) -> Result<Self> {
21        let listener = TcpListener::bind(("127.0.0.1", listen_port.unwrap_or(0)))
22            .await
23            .context("failed to bind local bridge listener")?;
24        let local_addr = listener.local_addr()?;
25        let local_url = format!("http://{local_addr}");
26        let (shutdown_tx, shutdown_rx) = oneshot::channel();
27        let task = tokio::spawn(run_server(listener, upstream, shutdown_rx));
28
29        Ok(Self {
30            local_url,
31            shutdown: Some(shutdown_tx),
32            task,
33        })
34    }
35
36    pub fn local_proxy_url(&self) -> String {
37        self.local_url.clone()
38    }
39
40    pub async fn shutdown(mut self) -> Result<()> {
41        if let Some(shutdown) = self.shutdown.take() {
42            let _ = shutdown.send(());
43        }
44
45        self.task
46            .await
47            .context("local proxy bridge task failed to join")?
48    }
49}
50
51async fn run_server(
52    listener: TcpListener,
53    upstream: UpstreamProxy,
54    mut shutdown: oneshot::Receiver<()>,
55) -> Result<()> {
56    loop {
57        tokio::select! {
58            result = listener.accept() => {
59                let (client, peer) = result.context("failed to accept local proxy connection")?;
60                let upstream = upstream.clone();
61                tokio::spawn(async move {
62                    if let Err(error) = handle_client(client, upstream).await {
63                        tracing::debug!("local proxy connection from {peer} failed: {error:#}");
64                    }
65                });
66            }
67            _ = &mut shutdown => {
68                return Ok(());
69            }
70        }
71    }
72}
73
74async fn handle_client(mut client: TcpStream, upstream: UpstreamProxy) -> Result<()> {
75    let request_bytes = read_http_request_head(&mut client).await?;
76    let header_end = find_header_end(&request_bytes).context("HTTP header terminator not found")?;
77    let (head, leftover) = request_bytes.split_at(header_end);
78    let request = parse_http_request(head)?;
79
80    match upstream.scheme() {
81        ProxyScheme::Http => {
82            let mut upstream_stream = TcpStream::connect((upstream.host(), upstream.port()))
83                .await
84                .with_context(|| {
85                    format!(
86                        "failed to connect upstream HTTP proxy {}",
87                        upstream.authority()
88                    )
89                })?;
90            let outgoing =
91                add_proxy_authorization(head, upstream.basic_proxy_authorization().as_deref());
92            upstream_stream.write_all(&outgoing).await?;
93            if !leftover.is_empty() {
94                upstream_stream.write_all(leftover).await?;
95            }
96            tokio::io::copy_bidirectional(&mut client, &mut upstream_stream).await?;
97        }
98        ProxyScheme::Socks5 => {
99            if !request.method.eq_ignore_ascii_case("CONNECT") {
100                write_proxy_error(
101                    &mut client,
102                    501,
103                    "Only CONNECT is supported for SOCKS upstreams",
104                )
105                .await?;
106                bail!("non-CONNECT request is not supported for SOCKS upstreams");
107            }
108
109            let (target_host, target_port) = parse_host_port(&request.target)?;
110            let mut upstream_stream =
111                connect_via_socks5(&upstream, &target_host, target_port).await?;
112            client
113                .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
114                .await?;
115            if !leftover.is_empty() {
116                upstream_stream.write_all(leftover).await?;
117            }
118            tokio::io::copy_bidirectional(&mut client, &mut upstream_stream).await?;
119        }
120    }
121
122    Ok(())
123}
124
125async fn read_http_request_head(stream: &mut TcpStream) -> Result<Vec<u8>> {
126    let mut buffer = Vec::with_capacity(4096);
127    let mut chunk = [0_u8; 2048];
128
129    loop {
130        let read = stream.read(&mut chunk).await?;
131        if read == 0 {
132            bail!("connection closed before HTTP header was complete");
133        }
134
135        buffer.extend_from_slice(&chunk[..read]);
136        if find_header_end(&buffer).is_some() {
137            return Ok(buffer);
138        }
139        if buffer.len() > MAX_HEADER_BYTES {
140            bail!("HTTP proxy request header is too large");
141        }
142    }
143}
144
/// The request-line fields of a proxied HTTP request that the bridge acts on.
#[derive(Debug, PartialEq, Eq)]
struct HttpRequest {
    /// Method token exactly as sent by the client, e.g. `CONNECT`.
    method: String,
    /// Request target exactly as sent, e.g. `discord.com:443` for `CONNECT`.
    target: String,
}
150
151fn parse_http_request(head: &[u8]) -> Result<HttpRequest> {
152    let text = str::from_utf8(head).context("HTTP request header is not valid UTF-8")?;
153    let first_line = text.lines().next().context("HTTP request is empty")?;
154    let mut parts = first_line.split_whitespace();
155    let method = parts.next().context("HTTP request is missing method")?;
156    let target = parts.next().context("HTTP request is missing target")?;
157    let version = parts.next().context("HTTP request is missing version")?;
158
159    if !version.starts_with("HTTP/") {
160        bail!("invalid HTTP proxy request version: {version}");
161    }
162
163    Ok(HttpRequest {
164        method: method.to_string(),
165        target: target.to_string(),
166    })
167}
168
/// Returns the offset just past the first `\r\n\r\n` in `buffer`, i.e. the
/// length of the HTTP head including its blank-line terminator, or `None` if
/// the terminator is not present.
fn find_header_end(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buffer
        .windows(TERMINATOR.len())
        .enumerate()
        .find_map(|(start, window)| (window == TERMINATOR).then_some(start + TERMINATOR.len()))
}
175
/// Returns a copy of the request `head` with a `Proxy-Authorization` header
/// inserted just before the blank line that terminates it.
///
/// The head is returned unchanged when `authorization` is `None`, when the
/// client already sent a `Proxy-Authorization` header (matched
/// case-insensitively), or when the head has no `\r\n\r\n` terminator.
///
/// Operates on raw bytes, so header values that are not valid UTF-8 pass
/// through untouched and the insertion offset always refers to `head` itself.
/// (A lossy UTF-8 conversion would shift offsets and misplace the header.)
fn add_proxy_authorization(head: &[u8], authorization: Option<&str>) -> Vec<u8> {
    const EXISTING_HEADER: &[u8] = b"\r\nproxy-authorization:";
    const HEADER_PREFIX: &[u8] = b"\r\nProxy-Authorization: ";
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let Some(authorization) = authorization else {
        return head.to_vec();
    };

    // Respect credentials the client supplied itself.
    let already_present = head
        .windows(EXISTING_HEADER.len())
        .any(|window| window.eq_ignore_ascii_case(EXISTING_HEADER));
    if already_present {
        return head.to_vec();
    }

    let Some(insert_at) = head
        .windows(TERMINATOR.len())
        .rposition(|window| window == TERMINATOR)
    else {
        return head.to_vec();
    };

    let mut outgoing =
        Vec::with_capacity(head.len() + HEADER_PREFIX.len() + authorization.len());
    outgoing.extend_from_slice(&head[..insert_at]);
    outgoing.extend_from_slice(HEADER_PREFIX);
    outgoing.extend_from_slice(authorization.as_bytes());
    outgoing.extend_from_slice(&head[insert_at..]);
    outgoing
}
199
200fn parse_host_port(value: &str) -> Result<(String, u16)> {
201    if let Some(rest) = value.strip_prefix('[') {
202        let (host, tail) = rest
203            .split_once(']')
204            .context("invalid bracketed IPv6 CONNECT target")?;
205        let port = tail
206            .strip_prefix(':')
207            .context("IPv6 CONNECT target is missing port")?
208            .parse()
209            .context("invalid CONNECT target port")?;
210        return Ok((host.to_string(), port));
211    }
212
213    let (host, port) = value
214        .rsplit_once(':')
215        .context("CONNECT target must be host:port")?;
216    if host.is_empty() {
217        bail!("CONNECT target host cannot be empty");
218    }
219
220    Ok((
221        host.to_string(),
222        port.parse().context("invalid CONNECT target port")?,
223    ))
224}
225
226async fn connect_via_socks5(
227    proxy: &UpstreamProxy,
228    target_host: &str,
229    target_port: u16,
230) -> Result<TcpStream> {
231    let mut stream = TcpStream::connect((proxy.host(), proxy.port()))
232        .await
233        .with_context(|| {
234            format!(
235                "failed to connect upstream SOCKS5 proxy {}",
236                proxy.authority()
237            )
238        })?;
239
240    if proxy.has_auth() {
241        stream.write_all(&[0x05, 0x02, 0x00, 0x02]).await?;
242    } else {
243        stream.write_all(&[0x05, 0x01, 0x00]).await?;
244    }
245
246    let mut method_response = [0_u8; 2];
247    stream.read_exact(&mut method_response).await?;
248    if method_response[0] != 0x05 {
249        bail!("invalid SOCKS5 method response");
250    }
251
252    match method_response[1] {
253        0x00 => {}
254        0x02 => authenticate_socks5(proxy, &mut stream).await?,
255        0xff => bail!("SOCKS5 proxy rejected all authentication methods"),
256        method => bail!("SOCKS5 proxy selected unsupported authentication method {method:#x}"),
257    }
258
259    let request = build_socks5_connect_request(target_host, target_port)?;
260    stream.write_all(&request).await?;
261
262    let mut response = [0_u8; 4];
263    stream.read_exact(&mut response).await?;
264    if response[0] != 0x05 {
265        bail!("invalid SOCKS5 connect response");
266    }
267    if response[1] != 0x00 {
268        bail!("SOCKS5 connect failed with code {:#x}", response[1]);
269    }
270
271    read_socks5_bound_address(&mut stream, response[3]).await?;
272    Ok(stream)
273}
274
275async fn authenticate_socks5(proxy: &UpstreamProxy, stream: &mut TcpStream) -> Result<()> {
276    let username = proxy.username().unwrap_or_default().as_bytes();
277    let password = proxy.password().unwrap_or_default().as_bytes();
278    if username.len() > u8::MAX as usize || password.len() > u8::MAX as usize {
279        bail!("SOCKS5 username and password must be at most 255 bytes");
280    }
281
282    let mut request = Vec::with_capacity(username.len() + password.len() + 3);
283    request.push(0x01);
284    request.push(username.len() as u8);
285    request.extend_from_slice(username);
286    request.push(password.len() as u8);
287    request.extend_from_slice(password);
288    stream.write_all(&request).await?;
289
290    let mut response = [0_u8; 2];
291    stream.read_exact(&mut response).await?;
292    if response != [0x01, 0x00] {
293        bail!("SOCKS5 username/password authentication failed");
294    }
295    Ok(())
296}
297
298fn build_socks5_connect_request(target_host: &str, target_port: u16) -> Result<Vec<u8>> {
299    let mut request = vec![0x05, 0x01, 0x00];
300
301    match target_host.parse::<IpAddr>() {
302        Ok(IpAddr::V4(address)) => {
303            request.push(0x01);
304            request.extend_from_slice(&address.octets());
305        }
306        Ok(IpAddr::V6(address)) => {
307            request.push(0x04);
308            request.extend_from_slice(&address.octets());
309        }
310        Err(_) => {
311            let host = target_host.as_bytes();
312            if host.len() > u8::MAX as usize {
313                bail!("SOCKS5 target host is too long");
314            }
315            request.push(0x03);
316            request.push(host.len() as u8);
317            request.extend_from_slice(host);
318        }
319    }
320
321    request.extend_from_slice(&target_port.to_be_bytes());
322    Ok(request)
323}
324
325async fn read_socks5_bound_address(stream: &mut TcpStream, address_type: u8) -> Result<()> {
326    match address_type {
327        0x01 => {
328            let mut buffer = [0_u8; 4 + 2];
329            stream.read_exact(&mut buffer).await?;
330        }
331        0x03 => {
332            let mut length = [0_u8; 1];
333            stream.read_exact(&mut length).await?;
334            let mut buffer = vec![0_u8; length[0] as usize + 2];
335            stream.read_exact(&mut buffer).await?;
336        }
337        0x04 => {
338            let mut buffer = [0_u8; 16 + 2];
339            stream.read_exact(&mut buffer).await?;
340        }
341        other => bail!("invalid SOCKS5 address type {other:#x}"),
342    }
343    Ok(())
344}
345
346async fn write_proxy_error(stream: &mut TcpStream, code: u16, message: &str) -> Result<()> {
347    let response = format!(
348        "HTTP/1.1 {code} {message}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{message}",
349        message.len()
350    );
351    stream.write_all(response.as_bytes()).await?;
352    Ok(())
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_connect_targets() {
        assert_eq!(
            parse_host_port("discord.com:443").unwrap(),
            ("discord.com".to_string(), 443)
        );
        assert_eq!(
            parse_host_port("[::1]:443").unwrap(),
            ("::1".to_string(), 443)
        );
    }

    #[test]
    fn rejects_malformed_connect_targets() {
        for target in [
            "discord.com",
            "discord.com:",
            ":443",
            "[::1]",
            "[::1]:https",
            "discord.com:70000",
        ] {
            assert!(parse_host_port(target).is_err(), "{target} should be rejected");
        }
    }

    #[test]
    fn parses_request_line() {
        let request = parse_http_request(
            b"CONNECT discord.com:443 HTTP/1.1\r\nHost: discord.com:443\r\n\r\n",
        )
        .unwrap();

        assert_eq!(
            request,
            HttpRequest {
                method: "CONNECT".to_string(),
                target: "discord.com:443".to_string(),
            }
        );
    }

    #[test]
    fn rejects_malformed_request_lines() {
        assert!(parse_http_request(b"").is_err());
        assert!(parse_http_request(b"CONNECT\r\n\r\n").is_err());
        assert!(parse_http_request(b"CONNECT discord.com:443 SPDY/3\r\n\r\n").is_err());
    }

    #[test]
    fn finds_end_of_request_head() {
        assert_eq!(find_header_end(b"GET / HTTP/1.1\r\n\r\nbody"), Some(18));
        assert_eq!(find_header_end(b"GET / HTTP/1.1\r\n"), None);
    }

    #[test]
    fn injects_proxy_authorization_header() {
        let head = b"CONNECT discord.com:443 HTTP/1.1\r\nHost: discord.com:443\r\n\r\n";

        let outgoing = add_proxy_authorization(head, Some("Basic abc"));
        let text = String::from_utf8(outgoing).unwrap();

        assert!(text.contains("\r\nProxy-Authorization: Basic abc\r\n"));
        assert!(text.ends_with("\r\n\r\n"));
    }

    #[test]
    fn leaves_head_untouched_without_credentials() {
        let head = b"CONNECT discord.com:443 HTTP/1.1\r\n\r\n";

        assert_eq!(add_proxy_authorization(head, None), head.to_vec());
    }

    #[test]
    fn does_not_duplicate_proxy_authorization_header() {
        let head = b"CONNECT discord.com:443 HTTP/1.1\r\nProxy-Authorization: Basic old\r\n\r\n";

        let outgoing = add_proxy_authorization(head, Some("Basic new"));

        assert_eq!(outgoing, head);
    }

    #[test]
    fn detects_existing_proxy_authorization_case_insensitively() {
        let head = b"CONNECT discord.com:443 HTTP/1.1\r\nPROXY-AUTHORIZATION: Basic old\r\n\r\n";

        assert_eq!(add_proxy_authorization(head, Some("Basic new")), head.to_vec());
    }

    #[test]
    fn builds_domain_socks_connect_request() {
        let request = build_socks5_connect_request("discord.com", 443).unwrap();

        assert_eq!(&request[..5], &[0x05, 0x01, 0x00, 0x03, 11]);
        assert_eq!(&request[5..16], b"discord.com");
        assert_eq!(&request[16..], &443_u16.to_be_bytes());
    }

    #[test]
    fn builds_ipv4_socks_connect_request() {
        let request = build_socks5_connect_request("127.0.0.1", 8080).unwrap();

        assert_eq!(
            request,
            vec![0x05_u8, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0x1f, 0x90]
        );
    }

    #[test]
    fn builds_ipv6_socks_connect_request() {
        let request = build_socks5_connect_request("::1", 443).unwrap();

        let mut expected = vec![0x05_u8, 0x01, 0x00, 0x04];
        expected.extend_from_slice(&std::net::Ipv6Addr::LOCALHOST.octets());
        expected.extend_from_slice(&443_u16.to_be_bytes());
        assert_eq!(request, expected);
    }

    #[test]
    fn rejects_overlong_socks_domain() {
        assert!(build_socks5_connect_request(&"a".repeat(255), 443).is_ok());
        assert!(build_socks5_connect_request(&"a".repeat(256), 443).is_err());
    }
}