Skip to main content

borer_core/proto/
http_connect.rs

1use std::{str::Split, sync::Arc};
2
3use anyhow::{Context, anyhow};
4use log::debug;
5use socks5_proto::Address;
6use tokio::io::{AsyncRead, AsyncReadExt};
7
/// Size limit, in bytes, applied to an HTTP request head (16 KiB).
const MAX_HTTP_REQUEST_SIZE: usize = 16 * 1024;
/// Error message used for every request that is rejected as malformed.
const BAD_REQUEST: &str = "BadRequest";
10
11/// Parsed CONNECT-style or proxied HTTP request metadata.
12pub struct HttpConnectRequest {
13    addr: Address,
14    mode: RequestMode,
15}
16
17#[derive(Clone, Debug, Eq, PartialEq)]
18enum RequestMode {
19    Connect,
20    Forward(Nugget),
21}
22
/// Preserved raw HTTP request bytes that should be forwarded upstream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Nugget {
    /// Reference-counted copy of the request bytes.
    data: Arc<Vec<u8>>,
}
28
29/// Read bytes until the end of the HTTP header block (`\r\n\r\n`).
30pub async fn read_http_request_end<T: AsyncRead + Unpin>(r: &mut T) -> anyhow::Result<Vec<u8>> {
31    let mut buf = Vec::new();
32    for _i in 0..MAX_HTTP_REQUEST_SIZE {
33        let u1 = r.read_u8().await?;
34        buf.push(u1);
35        if u1 == b'\r' {
36            let [u2, u3, u4] = {
37                let mut x = [0u8; 3];
38                r.read_exact(&mut x).await.map(|_| x)
39            }?;
40            buf.push(u2);
41            buf.push(u3);
42            buf.push(u4);
43            if u2 == b'\n' && u3 == b'\r' && u4 == b'\n' {
44                break;
45            }
46        }
47    }
48    Ok(buf)
49}
50
51impl HttpConnectRequest {
52    /// Parse an HTTP request into a target address and optional forwardable payload.
53    pub fn parse(http_request: &[u8]) -> anyhow::Result<Self> {
54        Self::precondition_size(http_request)?;
55        Self::precondition_legal_characters(http_request)?;
56
57        let http_request_as_string =
58            String::from_utf8(http_request.to_vec()).context("contains only ASCII")?;
59
60        let mut lines = http_request_as_string.split("\r\n");
61        let request_line =
62            Self::parse_request_line(lines.next().ok_or_else(|| anyhow!(BAD_REQUEST))?)?;
63
64        let (host, mode) = match request_line.mode {
65            ParsedRequestMode::Connect => (request_line.target.to_string(), RequestMode::Connect),
66            ParsedRequestMode::Forward => (
67                Self::extract_destination_host(&mut lines, request_line.target)
68                    .unwrap_or_else(|| request_line.target.to_string()),
69                RequestMode::Forward(Nugget::new(http_request)),
70            ),
71        };
72
73        Ok(Self {
74            addr: Self::host_to_address(host)?,
75            mode,
76        })
77    }
78
79    /// Return the upstream address extracted from this request.
80    pub fn addr(&self) -> &Address {
81        &self.addr
82    }
83
84    /// Return the preserved raw request when this is a forward-proxy request.
85    pub fn nugget(&self) -> Option<&Nugget> {
86        match &self.mode {
87            RequestMode::Connect => None,
88            RequestMode::Forward(nugget) => Some(nugget),
89        }
90    }
91
92    fn host_to_address(host: String) -> anyhow::Result<Address> {
93        let mut parts = host.rsplitn(2, ':');
94        let port = parts
95            .next()
96            .ok_or_else(|| anyhow!("parse http target port failed"))?;
97        let domain = parts
98            .next()
99            .ok_or_else(|| anyhow!("parse http target host failed: {host}"))?;
100
101        if domain.is_empty() {
102            Err(anyhow!("parse http target host failed: {host}"))
103        } else {
104            Ok(Address::DomainAddress(
105                domain.as_bytes().to_vec(),
106                port.parse()?,
107            ))
108        }
109    }
110
111    fn extract_destination_host(lines: &mut Split<&str>, endpoint: &str) -> Option<String> {
112        const HOST_HEADER: &str = "host:";
113
114        lines
115            .find(|line| line.to_ascii_lowercase().starts_with(HOST_HEADER))
116            .map(|line| line[HOST_HEADER.len()..].trim())
117            .map(|host| {
118                let mut host = String::from(host);
119                if host.rfind(':').is_none() {
120                    let default_port = if endpoint.to_ascii_lowercase().starts_with("https://") {
121                        ":443"
122                    } else {
123                        ":80"
124                    };
125                    host.push_str(default_port);
126                }
127                host
128            })
129    }
130
131    fn parse_request_line(request_line: &str) -> anyhow::Result<ParsedRequestLine<'_>> {
132        let request_line_items = request_line.split(' ').collect::<Vec<&str>>();
133        Self::precondition_well_formed(request_line, &request_line_items)?;
134
135        let method = request_line_items[0];
136        let target = request_line_items[1];
137        let version = request_line_items[2];
138
139        let mode = Self::parse_request_mode(method);
140        Self::check_version(version)?;
141
142        Ok(ParsedRequestLine { target, mode })
143    }
144
145    fn precondition_well_formed(
146        request_line: &str,
147        request_line_items: &[&str],
148    ) -> anyhow::Result<()> {
149        if request_line_items.len() != 3 {
150            debug!("bad request line: `{request_line:?}`");
151            Err(anyhow!(BAD_REQUEST))
152        } else {
153            Ok(())
154        }
155    }
156
157    fn check_version(version: &str) -> anyhow::Result<()> {
158        if version != "HTTP/1.1" {
159            debug!("bad version {}", version);
160            Err(anyhow!(BAD_REQUEST))
161        } else {
162            Ok(())
163        }
164    }
165
166    fn parse_request_mode(method: &str) -> ParsedRequestMode {
167        if method == "CONNECT" {
168            ParsedRequestMode::Connect
169        } else {
170            ParsedRequestMode::Forward
171        }
172    }
173
174    fn precondition_legal_characters(http_request: &[u8]) -> anyhow::Result<()> {
175        for b in http_request {
176            match b {
177                // non-ascii characters don't make sense in this context
178                32..=126 | 9 | 10 | 13 => {}
179                _ => {
180                    debug!("bad request header. Illegal character: {:#04x}", b);
181                    return Err(anyhow!(BAD_REQUEST));
182                }
183            }
184        }
185        Ok(())
186    }
187
188    fn precondition_size(http_request: &[u8]) -> anyhow::Result<()> {
189        if http_request.len() >= MAX_HTTP_REQUEST_SIZE {
190            debug!(
191                "bad request header. Size {} exceeds limit {}",
192                http_request.len(),
193                MAX_HTTP_REQUEST_SIZE
194            );
195            Err(anyhow!(BAD_REQUEST))
196        } else {
197            Ok(())
198        }
199    }
200}
201
202struct ParsedRequestLine<'a> {
203    target: &'a str,
204    mode: ParsedRequestMode,
205}
206
/// Request mode as read from the method, before any payload is attached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParsedRequestMode {
    /// The method was `CONNECT`.
    Connect,
    /// The method was anything else.
    Forward,
}
212
213impl Nugget {
214    /// Store an owned copy of the raw request bytes.
215    pub fn new<T: Into<Vec<u8>>>(v: T) -> Self {
216        Self {
217            data: Arc::new(v.into()),
218        }
219    }
220
221    /// Access the raw request bytes.
222    pub fn data(&self) -> Arc<Vec<u8>> {
223        self.data.clone()
224    }
225}
226
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use socks5_proto::Address;

    use super::{HttpConnectRequest, Nugget, read_http_request_end};

    /// Parse `raw`, which is expected to be rejected, and return the rendered error.
    fn rejection_message(raw: &[u8]) -> String {
        HttpConnectRequest::parse(raw)
            .err()
            .expect("should reject")
            .to_string()
    }

    // The reader must stop right after the blank line and leave the payload unread.
    #[tokio::test]
    async fn read_http_request_end_reads_until_double_crlf() {
        const HEAD: &[u8] = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n";
        let stream = [HEAD, b"payload".as_slice()].concat();
        let mut reader = Cursor::new(stream.as_slice());

        let head = read_http_request_end(&mut reader).await.unwrap();

        assert_eq!(head, HEAD);
    }

    // CONNECT takes its destination from the request target, not from `Host`.
    #[test]
    fn parse_connect_request_uses_request_target_and_has_no_nugget() {
        let parsed = HttpConnectRequest::parse(
            b"CONNECT example.com:443 HTTP/1.1\r\nHost: ignored.example.com\r\n\r\n",
        )
        .unwrap();

        assert_eq!(
            *parsed.addr(),
            Address::DomainAddress(b"example.com".to_vec(), 443)
        );
        assert!(parsed.nugget().is_none());
    }

    // Other methods take the destination from `Host` and keep the raw bytes.
    #[test]
    fn parse_non_connect_request_uses_host_header_and_preserves_request() {
        let raw = b"GET https://upstream.example/path HTTP/1.1\r\nHost: service.internal\r\n\r\n";

        let parsed = HttpConnectRequest::parse(raw).unwrap();

        assert_eq!(
            *parsed.addr(),
            Address::DomainAddress(b"service.internal".to_vec(), 443)
        );
        assert_eq!(parsed.nugget(), Some(&Nugget::new(raw.as_slice())));
    }

    // A `Host` value without a port gets :80 when the target is not https.
    #[test]
    fn parse_non_connect_without_host_port_adds_http_default_port() {
        let parsed = HttpConnectRequest::parse(
            b"GET http://upstream.example/path HTTP/1.1\r\nHost: service.internal\r\n\r\n",
        )
        .unwrap();

        assert_eq!(
            *parsed.addr(),
            Address::DomainAddress(b"service.internal".to_vec(), 80)
        );
    }

    // The offending byte here is the control character 0x01.
    #[test]
    fn parse_rejects_non_ascii_bytes() {
        let message =
            rejection_message(b"CONNECT example.com:443 HTTP/1.1\r\nHost: examp\x01e.com\r\n\r\n");

        assert!(message.contains("BadRequest"));
    }

    // Only HTTP/1.1 is accepted.
    #[test]
    fn parse_rejects_invalid_http_version() {
        let message =
            rejection_message(b"CONNECT example.com:443 HTTP/1.0\r\nHost: example.com:443\r\n\r\n");

        assert!(message.contains("BadRequest"));
    }
}