borer_core/proto/
http_connect.rs1use std::{str::Split, sync::Arc};
2
3use anyhow::{Context, anyhow};
4use log::debug;
5use socks5_proto::Address;
6use tokio::io::{AsyncRead, AsyncReadExt};
7
/// Limit applied to a request head: it bounds the read loop in `read_http_request_end` and is
/// the length at which `precondition_size` starts rejecting a request.
const MAX_HTTP_REQUEST_SIZE: usize = 16384;
/// Error message used when a request is rejected as malformed.
const BAD_REQUEST: &str = "BadRequest";
10
11pub struct HttpConnectRequest {
13 addr: Address,
14 mode: RequestMode,
15}
16
/// How a parsed request is to be handled.
#[derive(Clone, Debug, Eq, PartialEq)]
enum RequestMode {
    /// The request used the `CONNECT` method; no request bytes are kept.
    Connect,
    /// Any other method; carries the raw request bytes as a [`Nugget`].
    Forward(Nugget),
}
22
/// A byte payload held behind an [`Arc`].
///
/// `HttpConnectRequest::parse` stores the original request bytes in one of these for every
/// request that is not a `CONNECT`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Nugget {
    /// The captured bytes.
    data: Arc<Vec<u8>>,
}
28
29pub async fn read_http_request_end<T: AsyncRead + Unpin>(r: &mut T) -> anyhow::Result<Vec<u8>> {
31 let mut buf = Vec::new();
32 for _i in 0..MAX_HTTP_REQUEST_SIZE {
33 let u1 = r.read_u8().await?;
34 buf.push(u1);
35 if u1 == b'\r' {
36 let [u2, u3, u4] = {
37 let mut x = [0u8; 3];
38 r.read_exact(&mut x).await.map(|_| x)
39 }?;
40 buf.push(u2);
41 buf.push(u3);
42 buf.push(u4);
43 if u2 == b'\n' && u3 == b'\r' && u4 == b'\n' {
44 break;
45 }
46 }
47 }
48 Ok(buf)
49}
50
51impl HttpConnectRequest {
52 pub fn parse(http_request: &[u8]) -> anyhow::Result<Self> {
54 Self::precondition_size(http_request)?;
55 Self::precondition_legal_characters(http_request)?;
56
57 let http_request_as_string =
58 String::from_utf8(http_request.to_vec()).context("contains only ASCII")?;
59
60 let mut lines = http_request_as_string.split("\r\n");
61 let request_line =
62 Self::parse_request_line(lines.next().ok_or_else(|| anyhow!(BAD_REQUEST))?)?;
63
64 let (host, mode) = match request_line.mode {
65 ParsedRequestMode::Connect => (request_line.target.to_string(), RequestMode::Connect),
66 ParsedRequestMode::Forward => (
67 Self::extract_destination_host(&mut lines, request_line.target)
68 .unwrap_or_else(|| request_line.target.to_string()),
69 RequestMode::Forward(Nugget::new(http_request)),
70 ),
71 };
72
73 Ok(Self {
74 addr: Self::host_to_address(host)?,
75 mode,
76 })
77 }
78
79 pub fn addr(&self) -> &Address {
81 &self.addr
82 }
83
84 pub fn nugget(&self) -> Option<&Nugget> {
86 match &self.mode {
87 RequestMode::Connect => None,
88 RequestMode::Forward(nugget) => Some(nugget),
89 }
90 }
91
92 fn host_to_address(host: String) -> anyhow::Result<Address> {
93 let mut parts = host.rsplitn(2, ':');
94 let port = parts
95 .next()
96 .ok_or_else(|| anyhow!("parse http target port failed"))?;
97 let domain = parts
98 .next()
99 .ok_or_else(|| anyhow!("parse http target host failed: {host}"))?;
100
101 if domain.is_empty() {
102 Err(anyhow!("parse http target host failed: {host}"))
103 } else {
104 Ok(Address::DomainAddress(
105 domain.as_bytes().to_vec(),
106 port.parse()?,
107 ))
108 }
109 }
110
/// Finds the first `Host` header in `lines` and returns its value as `host:port`.
///
/// The header name is matched case-insensitively. When the value carries no port, `:443` is
/// appended if `endpoint` starts with `https://` (case-insensitive), otherwise `:80`.
/// A bracketed IPv6 literal without a port (`[::1]`) is recognised as having none, so it also
/// receives the default port.
///
/// Returns `None` when no line starts with `host:`. The iterator is advanced past the
/// matching line (or exhausted when there is no match).
fn extract_destination_host(lines: &mut Split<&str>, endpoint: &str) -> Option<String> {
    const HOST_HEADER: &str = "host:";
    const HTTPS_SCHEME: &str = "https://";

    /// Case-insensitive ASCII prefix test that does not allocate.
    fn has_prefix_ignore_case(text: &str, prefix: &str) -> bool {
        text.get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    }

    let line = lines.find(|line| has_prefix_ignore_case(line, HOST_HEADER))?;
    let mut host = String::from(line[HOST_HEADER.len()..].trim());

    // A colon counts as a port separator only when no `]` follows it; otherwise it belongs
    // to the inside of a bracketed IPv6 literal.
    let has_port = host
        .rfind(':')
        .is_some_and(|idx| !host[idx..].contains(']'));

    if !has_port {
        let default_port = if has_prefix_ignore_case(endpoint, HTTPS_SCHEME) {
            ":443"
        } else {
            ":80"
        };
        host.push_str(default_port);
    }
    Some(host)
}
130
131 fn parse_request_line(request_line: &str) -> anyhow::Result<ParsedRequestLine<'_>> {
132 let request_line_items = request_line.split(' ').collect::<Vec<&str>>();
133 Self::precondition_well_formed(request_line, &request_line_items)?;
134
135 let method = request_line_items[0];
136 let target = request_line_items[1];
137 let version = request_line_items[2];
138
139 let mode = Self::parse_request_mode(method);
140 Self::check_version(version)?;
141
142 Ok(ParsedRequestLine { target, mode })
143 }
144
145 fn precondition_well_formed(
146 request_line: &str,
147 request_line_items: &[&str],
148 ) -> anyhow::Result<()> {
149 if request_line_items.len() != 3 {
150 debug!("bad request line: `{request_line:?}`");
151 Err(anyhow!(BAD_REQUEST))
152 } else {
153 Ok(())
154 }
155 }
156
157 fn check_version(version: &str) -> anyhow::Result<()> {
158 if version != "HTTP/1.1" {
159 debug!("bad version {}", version);
160 Err(anyhow!(BAD_REQUEST))
161 } else {
162 Ok(())
163 }
164 }
165
166 fn parse_request_mode(method: &str) -> ParsedRequestMode {
167 if method == "CONNECT" {
168 ParsedRequestMode::Connect
169 } else {
170 ParsedRequestMode::Forward
171 }
172 }
173
174 fn precondition_legal_characters(http_request: &[u8]) -> anyhow::Result<()> {
175 for b in http_request {
176 match b {
177 32..=126 | 9 | 10 | 13 => {}
179 _ => {
180 debug!("bad request header. Illegal character: {:#04x}", b);
181 return Err(anyhow!(BAD_REQUEST));
182 }
183 }
184 }
185 Ok(())
186 }
187
188 fn precondition_size(http_request: &[u8]) -> anyhow::Result<()> {
189 if http_request.len() >= MAX_HTTP_REQUEST_SIZE {
190 debug!(
191 "bad request header. Size {} exceeds limit {}",
192 http_request.len(),
193 MAX_HTTP_REQUEST_SIZE
194 );
195 Err(anyhow!(BAD_REQUEST))
196 } else {
197 Ok(())
198 }
199 }
200}
201
/// The parts of a request line that `HttpConnectRequest::parse` uses afterwards.
struct ParsedRequestLine<'a> {
    /// The second space-separated token of the request line, borrowed from the input.
    target: &'a str,
    /// Handling mode derived from the method token.
    mode: ParsedRequestMode,
}
206
/// Handling mode as decided from the method token alone.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ParsedRequestMode {
    /// The method was exactly `CONNECT`.
    Connect,
    /// Any other method.
    Forward,
}
212
213impl Nugget {
214 pub fn new<T: Into<Vec<u8>>>(v: T) -> Self {
216 Self {
217 data: Arc::new(v.into()),
218 }
219 }
220
221 pub fn data(&self) -> Arc<Vec<u8>> {
223 self.data.clone()
224 }
225}
226
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use socks5_proto::Address;

    use super::{HttpConnectRequest, Nugget, read_http_request_end};

    /// The returned buffer stops right after the blank line; the trailing `payload` is not
    /// part of it.
    #[tokio::test]
    async fn read_http_request_end_reads_until_double_crlf() {
        let raw = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\npayload";
        let mut cursor = Cursor::new(raw.as_slice());

        let actual = read_http_request_end(&mut cursor).await.unwrap();

        assert_eq!(
            actual,
            b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n"
        );
    }

    /// For `CONNECT` the request target wins over the `Host` header and nothing is captured.
    #[test]
    fn parse_connect_request_uses_request_target_and_has_no_nugget() {
        let raw = b"CONNECT example.com:443 HTTP/1.1\r\nHost: ignored.example.com\r\n\r\n";

        let request = HttpConnectRequest::parse(raw).unwrap();

        assert_eq!(
            request.addr().clone(),
            Address::DomainAddress(b"example.com".to_vec(), 443)
        );
        assert!(request.nugget().is_none());
    }

    /// For other methods the `Host` header gives the destination and the raw bytes are kept.
    #[test]
    fn parse_non_connect_request_uses_host_header_and_preserves_request() {
        let raw = b"GET https://upstream.example/path HTTP/1.1\r\nHost: service.internal\r\n\r\n";

        let request = HttpConnectRequest::parse(raw).unwrap();

        assert_eq!(
            request.addr().clone(),
            Address::DomainAddress(b"service.internal".to_vec(), 443)
        );
        assert_eq!(
            request.nugget().cloned().unwrap(),
            Nugget::new(raw.as_slice())
        );
    }

    /// A port-less `Host` header combined with an `http://` target defaults to port 80.
    #[test]
    fn parse_non_connect_without_host_port_adds_http_default_port() {
        let raw = b"GET http://upstream.example/path HTTP/1.1\r\nHost: service.internal\r\n\r\n";

        let request = HttpConnectRequest::parse(raw).unwrap();

        assert_eq!(
            request.addr().clone(),
            Address::DomainAddress(b"service.internal".to_vec(), 80)
        );
    }

    /// `0x01` is an ASCII control byte and `0xff` is not ASCII at all; both lie outside the
    /// accepted character set and must be rejected.
    #[test]
    fn parse_rejects_illegal_bytes() {
        let with_control_byte = b"CONNECT example.com:443 HTTP/1.1\r\nHost: examp\x01e.com\r\n\r\n";
        let with_non_ascii_byte =
            b"CONNECT example.com:443 HTTP/1.1\r\nHost: examp\xffe.com\r\n\r\n";

        for raw in [with_control_byte.as_slice(), with_non_ascii_byte.as_slice()] {
            // `.err().expect(..)` rather than `unwrap_err()` so the test does not require
            // `HttpConnectRequest: Debug`.
            let err = HttpConnectRequest::parse(raw).err().expect("should reject");

            assert!(err.to_string().contains("BadRequest"));
        }
    }

    /// Only `HTTP/1.1` is accepted as the version token.
    #[test]
    fn parse_rejects_invalid_http_version() {
        let raw = b"CONNECT example.com:443 HTTP/1.0\r\nHost: example.com:443\r\n\r\n";

        let err = HttpConnectRequest::parse(raw).err().expect("should reject");

        assert!(err.to_string().contains("BadRequest"));
    }
}