pingora_core/protocols/
raw_connect.rs1use super::http::v1::client::HttpSession;
23use super::http::v1::common::*;
24use super::Stream;
25
26use bytes::{BufMut, BytesMut};
27use http::request::Parts as ReqHeader;
28use http::Version;
29use pingora_error::{Error, ErrorType::*, OrErr, Result};
30use pingora_http::ResponseHeader;
31use tokio::io::AsyncWriteExt;
32
33pub async fn connect(stream: Stream, request_header: &ReqHeader) -> Result<(Stream, ProxyDigest)> {
39 let mut http = HttpSession::new(stream);
40
41 let to_wire = http_req_header_to_wire_auth_form(request_header);
43 http.underlying_stream
44 .write_all(to_wire.as_ref())
45 .await
46 .or_err(WriteError, "while writing request headers")?;
47 http.underlying_stream
48 .flush()
49 .await
50 .or_err(WriteError, "while flushing request headers")?;
51
52 let resp_header = http.read_resp_header_parts().await?;
54 Ok((
55 http.underlying_stream,
56 validate_connect_response(resp_header)?,
57 ))
58}
59
60pub fn generate_connect_header<'a, H, S>(
62 host: &str,
63 port: u16,
64 headers: H,
65) -> Result<Box<ReqHeader>>
66where
67 S: AsRef<[u8]>,
68 H: Iterator<Item = (S, &'a Vec<u8>)>,
69{
70 let authority = if host.parse::<std::net::Ipv6Addr>().is_ok() {
73 format!("[{host}]:{port}")
74 } else {
75 format!("{host}:{port}")
76 };
77
78 let req = http::request::Builder::new()
79 .version(http::Version::HTTP_11)
80 .method(http::method::Method::CONNECT)
81 .uri(format!("https://{authority}/")) .header(http::header::HOST, &authority);
83
84 let (mut req, _) = match req.body(()) {
85 Ok(r) => r.into_parts(),
86 Err(e) => {
87 return Err(e).or_err(InvalidHTTPHeader, "Invalid CONNECT request");
88 }
89 };
90
91 for (k, v) in headers {
92 let header_name = http::header::HeaderName::from_bytes(k.as_ref())
93 .or_err(InvalidHTTPHeader, "Invalid CONNECT request")?;
94 let header_value = http::header::HeaderValue::from_bytes(v.as_slice())
95 .or_err(InvalidHTTPHeader, "Invalid CONNECT request")?;
96 req.headers.insert(header_name, header_value);
97 }
98
99 Ok(Box::new(req))
100}
101
102#[derive(Debug)]
104pub struct ProxyDigest {
105 pub response: Box<ResponseHeader>,
107}
108
109impl ProxyDigest {
110 pub fn new(response: Box<ResponseHeader>) -> Self {
111 ProxyDigest { response }
112 }
113}
114
115#[derive(Debug)]
117pub struct ConnectProxyError {
118 pub response: Box<ResponseHeader>,
120}
121
122impl ConnectProxyError {
123 pub fn boxed_new(response: Box<ResponseHeader>) -> Box<Self> {
124 Box::new(ConnectProxyError { response })
125 }
126}
127
128impl std::fmt::Display for ConnectProxyError {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 const PROXY_STATUS: &str = "proxy-status";
131
132 let reason = self
133 .response
134 .headers
135 .get(PROXY_STATUS)
136 .and_then(|s| s.to_str().ok())
137 .unwrap_or("missing proxy-status header value");
138 write!(
139 f,
140 "Failed CONNECT Response: status {}, proxy-status {reason}",
141 &self.response.status
142 )
143 }
144}
145
146impl std::error::Error for ConnectProxyError {}
147
148#[inline]
149fn http_req_header_to_wire_auth_form(req: &ReqHeader) -> BytesMut {
150 let mut buf = BytesMut::with_capacity(512);
151
152 let method = req.method.as_str().as_bytes();
154 buf.put_slice(method);
155 buf.put_u8(b' ');
156 if let Some(path) = req.uri.authority() {
158 buf.put_slice(path.as_str().as_bytes());
159 }
160 buf.put_u8(b' ');
161
162 let version = match req.version {
163 Version::HTTP_09 => "HTTP/0.9",
164 Version::HTTP_10 => "HTTP/1.0",
165 Version::HTTP_11 => "HTTP/1.1",
166 _ => "HTTP/0.9",
167 };
168 buf.put_slice(version.as_bytes());
169 buf.put_slice(CRLF);
170
171 let headers = &req.headers;
173 for (key, value) in headers.iter() {
174 buf.put_slice(key.as_ref());
175 buf.put_slice(HEADER_KV_DELIMITER);
176 buf.put_slice(value.as_ref());
177 buf.put_slice(CRLF);
178 }
179
180 buf.put_slice(CRLF);
181 buf
182}
183
184#[inline]
185fn validate_connect_response(resp: Box<ResponseHeader>) -> Result<ProxyDigest> {
186 if !resp.status.is_success() {
187 return Error::e_because(
188 ConnectProxyFailure,
189 "None 2xx code",
190 ConnectProxyError::boxed_new(resp),
191 );
192 }
193
194 if resp.headers.get(http::header::TRANSFER_ENCODING).is_some() {
198 return Error::e_because(
199 ConnectProxyFailure,
200 "Invalid Transfer-Encoding presents",
201 ConnectProxyError::boxed_new(resp),
202 );
203 }
204 Ok(ProxyDigest::new(resp))
205}
206
#[cfg(test)]
mod test_sync {
    use super::*;
    use std::collections::BTreeMap;
    use tokio_test::io::Builder;

    /// Generates a CONNECT header for `host:123` with one extra `foo: bar` header.
    /// Asserts the method, the authority, the Host header and the extra header.
    fn check_generated_header(host: &str, expected_authority: &str) {
        let extra = BTreeMap::from([(String::from("foo"), b"bar".to_vec())]);
        let req = generate_connect_header(host, 123, extra.iter()).unwrap();

        assert_eq!(req.method, http::method::Method::CONNECT);
        assert_eq!(req.uri.authority().unwrap(), expected_authority);
        assert_eq!(req.headers.get("Host").unwrap(), expected_authority);
        assert_eq!(req.headers.get("foo").unwrap(), "bar");
    }

    #[test]
    fn test_generate_connect_header() {
        check_generated_header("pingora.org", "pingora.org:123");
    }

    #[test]
    fn test_generate_connect_header_ipv6() {
        // A bare IPv6 literal must be bracketed in the authority.
        check_generated_header("::1", "[::1]:123");
    }

    #[test]
    fn test_request_to_wire_auth_form() {
        let (parts, _) = http::Request::builder()
            .method("CONNECT")
            .uri("https://pingora.org:123/")
            .header("Foo", "Bar")
            .body(())
            .unwrap()
            .into_parts();

        let serialized = http_req_header_to_wire_auth_form(&parts);
        let expected = &b"CONNECT pingora.org:123 HTTP/1.1\r\nfoo: Bar\r\n\r\n"[..];
        assert_eq!(expected, &serialized);
    }

    #[test]
    fn test_validate_connect_response() {
        let response = |code: u16| ResponseHeader::build(code, None).unwrap();

        // A plain 2xx is accepted.
        validate_connect_response(Box::new(response(200))).unwrap();

        // Non-2xx is rejected.
        assert!(validate_connect_response(Box::new(response(404))).is_err());

        // Content-Length on a 2xx is tolerated...
        let mut with_content_length = response(200);
        with_content_length.append_header("content-length", 0).unwrap();
        assert!(validate_connect_response(Box::new(with_content_length)).is_ok());

        // ...but Transfer-Encoding is not.
        let mut with_transfer_encoding = response(200);
        with_transfer_encoding
            .append_header("transfer-encoding", 0)
            .unwrap();
        assert!(validate_connect_response(Box::new(with_transfer_encoding)).is_err());
    }

    #[tokio::test]
    async fn test_connect_write_request() {
        const REQUEST_WIRE: &[u8] =
            b"CONNECT pingora.org:123 HTTP/1.1\r\nhost: pingora.org:123\r\n\r\n";
        let no_headers: BTreeMap<String, Vec<u8>> = BTreeMap::new();

        // The mock scripts no response bytes, so reading the response header must fail.
        let silent_proxy = Box::new(Builder::new().write(REQUEST_WIRE).build());
        let req = generate_connect_header("pingora.org", 123, no_headers.iter()).unwrap();
        assert!(connect(silent_proxy, &req).await.is_err());

        // The mock answers 200 OK, so the CONNECT succeeds.
        let answering_proxy = Box::new(
            Builder::new()
                .write(REQUEST_WIRE)
                .read(b"HTTP/1.1 200 OK\r\n\r\n")
                .build(),
        );
        let req = generate_connect_header("pingora.org", 123, no_headers.iter()).unwrap();
        assert!(connect(answering_proxy, &req).await.is_ok());
    }
}