1use std::any::Any;
23
24use super::http::v1::client::HttpSession;
25use super::http::v1::common::*;
26use super::Stream;
27
28use bytes::{BufMut, BytesMut};
29use http::request::Parts as ReqHeader;
30use http::Version;
31use pingora_error::{Error, ErrorType::*, OrErr, Result};
32use pingora_http::ResponseHeader;
33use tokio::io::AsyncWriteExt;
34
35pub async fn connect<P>(
41 stream: Stream,
42 request_header: &ReqHeader,
43 peer: &P,
44) -> Result<(Stream, ProxyDigest)>
45where
46 P: crate::upstreams::peer::Peer,
47{
48 let mut http = HttpSession::new(stream);
49
50 let to_wire = http_req_header_to_wire_auth_form(request_header);
52 http.underlying_stream
53 .write_all(to_wire.as_ref())
54 .await
55 .or_err(WriteError, "while writing request headers")?;
56 http.underlying_stream
57 .flush()
58 .await
59 .or_err(WriteError, "while flushing request headers")?;
60
61 let resp_header = http.read_resp_header_parts().await?;
63 Ok((
64 http.underlying_stream,
65 validate_connect_response(resp_header, peer, request_header)?,
66 ))
67}
68
69pub fn generate_connect_header<'a, H, S>(
71 host: &str,
72 port: u16,
73 headers: H,
74) -> Result<Box<ReqHeader>>
75where
76 S: AsRef<[u8]>,
77 H: Iterator<Item = (S, &'a Vec<u8>)>,
78{
79 let authority = if host.parse::<std::net::Ipv6Addr>().is_ok() {
82 format!("[{host}]:{port}")
83 } else {
84 format!("{host}:{port}")
85 };
86
87 let req = http::request::Builder::new()
88 .version(http::Version::HTTP_11)
89 .method(http::method::Method::CONNECT)
90 .uri(format!("https://{authority}/")) .header(http::header::HOST, &authority);
92
93 let (mut req, _) = match req.body(()) {
94 Ok(r) => r.into_parts(),
95 Err(e) => {
96 return Err(e).or_err(InvalidHTTPHeader, "Invalid CONNECT request");
97 }
98 };
99
100 for (k, v) in headers {
101 let header_name = http::header::HeaderName::from_bytes(k.as_ref())
102 .or_err(InvalidHTTPHeader, "Invalid CONNECT request")?;
103 let header_value = http::header::HeaderValue::from_bytes(v.as_slice())
104 .or_err(InvalidHTTPHeader, "Invalid CONNECT request")?;
105 req.headers.insert(header_name, header_value);
106 }
107
108 Ok(Box::new(req))
109}
110
111#[derive(Debug)]
113pub struct ProxyDigest {
114 pub response: Box<ResponseHeader>,
116 pub user_data: Option<Box<dyn Any + Send + Sync>>,
118}
119
120impl ProxyDigest {
121 pub fn new(
122 response: Box<ResponseHeader>,
123 user_data: Option<Box<dyn Any + Send + Sync>>,
124 ) -> Self {
125 ProxyDigest {
126 response,
127 user_data,
128 }
129 }
130}
131
132#[derive(Debug)]
134pub struct ConnectProxyError {
135 pub response: Box<ResponseHeader>,
137}
138
139impl ConnectProxyError {
140 pub fn boxed_new(response: Box<ResponseHeader>) -> Box<Self> {
141 Box::new(ConnectProxyError { response })
142 }
143}
144
145impl std::fmt::Display for ConnectProxyError {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 const PROXY_STATUS: &str = "proxy-status";
148
149 let reason = self
150 .response
151 .headers
152 .get(PROXY_STATUS)
153 .and_then(|s| s.to_str().ok())
154 .unwrap_or("missing proxy-status header value");
155 write!(
156 f,
157 "Failed CONNECT Response: status {}, proxy-status {reason}",
158 &self.response.status
159 )
160 }
161}
162
163impl std::error::Error for ConnectProxyError {}
164
165#[inline]
166fn http_req_header_to_wire_auth_form(req: &ReqHeader) -> BytesMut {
167 let mut buf = BytesMut::with_capacity(512);
168
169 let method = req.method.as_str().as_bytes();
171 buf.put_slice(method);
172 buf.put_u8(b' ');
173 if let Some(path) = req.uri.authority() {
175 buf.put_slice(path.as_str().as_bytes());
176 }
177 buf.put_u8(b' ');
178
179 let version = match req.version {
180 Version::HTTP_09 => "HTTP/0.9",
181 Version::HTTP_10 => "HTTP/1.0",
182 Version::HTTP_11 => "HTTP/1.1",
183 _ => "HTTP/0.9",
184 };
185 buf.put_slice(version.as_bytes());
186 buf.put_slice(CRLF);
187
188 let headers = &req.headers;
190 for (key, value) in headers.iter() {
191 buf.put_slice(key.as_ref());
192 buf.put_slice(HEADER_KV_DELIMITER);
193 buf.put_slice(value.as_ref());
194 buf.put_slice(CRLF);
195 }
196
197 buf.put_slice(CRLF);
198 buf
199}
200
201#[inline]
202fn validate_connect_response<P>(
203 resp: Box<ResponseHeader>,
204 peer: &P,
205 req: &ReqHeader,
206) -> Result<ProxyDigest>
207where
208 P: crate::upstreams::peer::Peer,
209{
210 if !resp.status.is_success() {
211 return Error::e_because(
212 ConnectProxyFailure,
213 "None 2xx code",
214 ConnectProxyError::boxed_new(resp),
215 );
216 }
217
218 if resp.headers.get(http::header::TRANSFER_ENCODING).is_some() {
222 return Error::e_because(
223 ConnectProxyFailure,
224 "Invalid Transfer-Encoding presents",
225 ConnectProxyError::boxed_new(resp),
226 );
227 }
228
229 let user_data = peer
230 .proxy_digest_user_data_hook()
231 .and_then(|hook| hook(req, &resp));
232 Ok(ProxyDigest::new(resp, user_data))
233}
234
#[cfg(test)]
mod test_sync {
    use super::*;
    use std::collections::BTreeMap;
    use tokio_test::io::Builder;

    /// Generates a CONNECT header for `host` on port 123 with one extra
    /// `foo: bar` header and checks method, authority, `Host` and `foo`.
    fn check_generated_header(host: &str, expected_authority: &str) {
        let extra = BTreeMap::from([(String::from("foo"), b"bar".to_vec())]);
        let header = generate_connect_header(host, 123, extra.iter()).unwrap();

        assert_eq!(header.method, http::method::Method::CONNECT);
        assert_eq!(header.uri.authority().unwrap(), expected_authority);
        assert_eq!(header.headers.get("Host").unwrap(), expected_authority);
        assert_eq!(header.headers.get("foo").unwrap(), "bar");
    }

    /// Builds a boxed response header with the given status code.
    fn boxed_response(code: u16) -> Box<ResponseHeader> {
        Box::new(ResponseHeader::build(code, None).unwrap())
    }

    #[test]
    fn test_generate_connect_header() {
        check_generated_header("pingora.org", "pingora.org:123");
    }

    #[test]
    fn test_generate_connect_header_ipv6() {
        // A bare IPv6 literal must end up bracketed in the authority.
        check_generated_header("::1", "[::1]:123");
    }

    #[test]
    fn test_request_to_wire_auth_form() {
        let (parts, _) = http::Request::builder()
            .method("CONNECT")
            .uri("https://pingora.org:123/")
            .header("Foo", "Bar")
            .body(())
            .unwrap()
            .into_parts();

        let serialized = http_req_header_to_wire_auth_form(&parts);
        let expected: &[u8] = b"CONNECT pingora.org:123 HTTP/1.1\r\nfoo: Bar\r\n\r\n";
        assert_eq!(expected, &serialized);
    }

    #[test]
    fn test_validate_connect_response() {
        use crate::upstreams::peer::BasicPeer;

        struct DummyUserData {
            some_num: i32,
            some_string: String,
        }

        let bare_peer = BasicPeer::new("127.0.0.1:80");
        let mut hooked_peer = bare_peer.clone();
        hooked_peer.options.proxy_digest_user_data_hook = Some(std::sync::Arc::new(
            |_req: &http::request::Parts, _resp: &pingora_http::ResponseHeader| {
                let payload: Box<dyn std::any::Any + Send + Sync> = Box::new(DummyUserData {
                    some_num: 42,
                    some_string: "test".to_string(),
                });
                Some(payload)
            },
        ));

        let (req_header, _) = http::Request::builder()
            .method("CONNECT")
            .uri("https://example.com:443/")
            .body(())
            .unwrap()
            .into_parts();

        // With a hook installed, the hook's data lands in the digest.
        let digest =
            validate_connect_response(boxed_response(200), &hooked_peer, &req_header).unwrap();
        assert!(digest.user_data.is_some());
        let data = digest
            .user_data
            .as_ref()
            .unwrap()
            .downcast_ref::<DummyUserData>()
            .unwrap();
        assert_eq!(data.some_num, 42);
        assert_eq!(data.some_string, "test");

        // Without a hook there is no user data.
        let digest =
            validate_connect_response(boxed_response(200), &bare_peer, &req_header).unwrap();
        assert!(digest.user_data.is_none());

        // A non-2xx status is rejected.
        assert!(
            validate_connect_response(boxed_response(404), &hooked_peer, &req_header).is_err()
        );

        // A Content-Length header is tolerated.
        let mut with_length = boxed_response(200);
        with_length.append_header("content-length", 0).unwrap();
        assert!(validate_connect_response(with_length, &bare_peer, &req_header).is_ok());

        // A Transfer-Encoding header is rejected.
        let mut with_encoding = boxed_response(200);
        with_encoding.append_header("transfer-encoding", 0).unwrap();
        assert!(validate_connect_response(with_encoding, &bare_peer, &req_header).is_err());
    }

    #[tokio::test]
    async fn test_connect_write_request() {
        use crate::upstreams::peer::BasicPeer;

        let request_wire = b"CONNECT pingora.org:123 HTTP/1.1\r\nhost: pingora.org:123\r\n\r\n";
        let response_wire = b"HTTP/1.1 200 OK\r\n\r\n";

        let no_headers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
        let peer = BasicPeer::new("127.0.0.1:123");

        // The mock only expects the request and offers no response bytes,
        // so the exchange cannot complete.
        let write_only_io = Box::new(Builder::new().write(request_wire).build());
        let header = generate_connect_header("pingora.org", 123, no_headers.iter()).unwrap();
        assert!(connect(write_only_io, &header, &peer).await.is_err());

        // With a 200 response scripted, the exchange succeeds.
        let full_io = Box::new(
            Builder::new()
                .write(request_wire)
                .read(response_wire)
                .build(),
        );
        let header = generate_connect_header("pingora.org", 123, no_headers.iter()).unwrap();
        let outcome = connect(full_io, &header, &peer).await;
        assert!(outcome.is_ok());
    }
}