conjure_runtime_raw/service/proxy/
connector.rs1use crate::service::proxy::{HttpProxyConfig, ProxyConfig};
15use bytes::Bytes;
16use futures::future::{self, BoxFuture};
17use http::header::{HOST, PROXY_AUTHORIZATION};
18use http::uri::Scheme;
19use http::{HeaderValue, Method, Request, StatusCode, Uri, Version};
20use http_body_util::Empty;
21use hyper::client::conn;
22use hyper::rt::{Read, ReadBufCursor, Write};
23use hyper_util::client::legacy::connect::{Connected, Connection};
24use std::convert::TryFrom;
25use std::error;
26use std::fmt;
27use std::future::Future;
28use std::io;
29use std::pin::{pin, Pin};
30use std::task::{Context, Poll};
31use tower_layer::Layer;
32use tower_service::Service;
33
/// Tower layer that routes connections through an HTTP proxy when one is configured.
///
/// `config` is `Some` only when the layer was built from `ProxyConfig::Http`; in every other
/// case the wrapped connector dials targets directly.
pub struct ProxyConnectorLayer {
    /// HTTP proxy settings copied out of the `ProxyConfig`, or `None` for direct connections.
    config: Option<HttpProxyConfig>,
}
42
43impl ProxyConnectorLayer {
44 pub fn new(config: &ProxyConfig) -> ProxyConnectorLayer {
45 let config = match config {
46 ProxyConfig::Http(config) => Some(config.clone()),
47 _ => None,
48 };
49
50 ProxyConnectorLayer { config }
51 }
52}
53
54impl<S> Layer<S> for ProxyConnectorLayer {
55 type Service = ProxyConnectorService<S>;
56
57 fn layer(&self, inner: S) -> Self::Service {
58 ProxyConnectorService {
59 inner,
60 config: self.config.clone(),
61 }
62 }
63}
64
/// Connector service produced by [`ProxyConnectorLayer`].
///
/// `inner` opens the actual connections. When `config` is set, the service dials the proxy
/// instead of the requested target.
#[derive(Clone)]
pub struct ProxyConnectorService<S> {
    /// Underlying connector used to open transports.
    inner: S,
    /// HTTP proxy settings, or `None` to connect directly.
    config: Option<HttpProxyConfig>,
}
70
71impl<S> Service<Uri> for ProxyConnectorService<S>
72where
73 S: Service<Uri> + Send,
74 S::Response: Read + Write + Unpin + Send + 'static,
75 S::Error: Into<Box<dyn error::Error + Sync + Send>>,
76 S::Future: Send + 'static,
77{
78 type Response = ProxyConnection<S::Response>;
79 type Error = Box<dyn error::Error + Sync + Send>;
80 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
81
82 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83 self.inner.poll_ready(cx).map_err(Into::into)
84 }
85
86 fn call(&mut self, req: Uri) -> Self::Future {
87 let config = match &self.config {
88 Some(proxy) => proxy.clone(),
89 None => {
90 let connect = self.inner.call(req);
91 return Box::pin(async move {
92 let stream = connect.await.map_err(Into::into)?;
93 Ok(ProxyConnection {
94 stream,
95 proxy: false,
96 })
97 });
98 }
99 };
100
101 let connect = self.inner.call(config.uri.clone());
102 Box::pin(async move {
103 let stream = connect.await.map_err(Into::into)?;
104
105 if req.scheme() == Some(&Scheme::HTTP) {
106 Ok(ProxyConnection {
107 stream,
108 proxy: true,
109 })
110 } else {
111 let stream = connect_https(stream, req, config).await?;
112 Ok(ProxyConnection {
113 stream,
114 proxy: false,
115 })
116 }
117 })
118 }
119}
120
121async fn connect_https<T>(
122 stream: T,
123 uri: Uri,
124 config: HttpProxyConfig,
125) -> Result<T, Box<dyn error::Error + Sync + Send>>
126where
127 T: Read + Write + Send + Unpin + 'static,
128{
129 let (mut sender, mut conn) = conn::http1::handshake(stream).await?;
130
131 let host = uri.host().ok_or("host missing from URI")?;
132 let authority = format!("{}:{}", host, uri.port_u16().unwrap_or(443));
133 let authority_uri = Uri::try_from(authority.clone()).unwrap();
134 let host = HeaderValue::try_from(authority).unwrap();
135
136 let mut request = Request::new(Empty::<Bytes>::new());
137 *request.method_mut() = Method::CONNECT;
138 *request.uri_mut() = authority_uri;
139 *request.version_mut() = Version::HTTP_11;
140 request.headers_mut().insert(HOST, host);
141 if let Some(credentials) = config.credentials {
142 request
143 .headers_mut()
144 .insert(PROXY_AUTHORIZATION, credentials);
145 }
146
147 let mut response = pin!(sender.send_request(request));
148 let response = future::poll_fn(|cx| {
149 let _ = conn.poll_without_shutdown(cx)?;
150 response.as_mut().poll(cx)
151 })
152 .await?;
153
154 if !response.status().is_success() {
155 return Err(ProxyTunnelError {
156 status: response.status(),
157 }
158 .into());
159 }
160
161 Ok(conn.into_parts().io)
162}
163
/// Connection returned by [`ProxyConnectorService`].
///
/// All I/O is forwarded to `stream`; `proxy` is reported to hyper-util through
/// [`Connection::connected`].
#[derive(Debug)]
pub struct ProxyConnection<T> {
    /// The underlying transport: a direct connection, a connection to the proxy, or a tunnel.
    stream: T,
    /// `true` only when `stream` is a plain connection to the proxy (not a tunnel).
    proxy: bool,
}
169
170impl<T> Read for ProxyConnection<T>
171where
172 T: Read + Unpin,
173{
174 fn poll_read(
175 mut self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 buf: ReadBufCursor<'_>,
178 ) -> Poll<io::Result<()>> {
179 Pin::new(&mut self.stream).poll_read(cx, buf)
180 }
181}
182
183impl<T> Write for ProxyConnection<T>
184where
185 T: Write + Unpin,
186{
187 fn poll_write(
188 mut self: Pin<&mut Self>,
189 cx: &mut Context<'_>,
190 buf: &[u8],
191 ) -> Poll<io::Result<usize>> {
192 Pin::new(&mut self.stream).poll_write(cx, buf)
193 }
194
195 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
196 Pin::new(&mut self.stream).poll_flush(cx)
197 }
198
199 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
200 Pin::new(&mut self.stream).poll_shutdown(cx)
201 }
202}
203
204impl<T> Connection for ProxyConnection<T>
205where
206 T: Connection,
207{
208 fn connected(&self) -> Connected {
209 self.stream.connected().proxy(self.proxy)
210 }
211}
212
213#[derive(Debug)]
214struct ProxyTunnelError {
215 status: StatusCode,
216}
217
218impl fmt::Display for ProxyTunnelError {
219 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
220 write!(fmt, "got status {} from HTTPS proxy", self.status)
221 }
222}
223
224impl error::Error for ProxyTunnelError {}
225
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::{self, BasicCredentials, HostAndPort};
    use hyper_util::rt::TokioIo;
    use tower_util::ServiceExt;

    /// Address the inner connector must be asked to dial whenever a proxy is configured.
    const PROXY_URI: &str = "http://127.0.0.1:1234";

    /// Scripted transport: a `tokio_test` mock adapted to hyper's I/O traits.
    struct MockConnection(TokioIo<tokio_test::io::Mock>);

    impl MockConnection {
        /// Finishes `builder`'s script and wraps the resulting mock.
        fn scripted(builder: &mut tokio_test::io::Builder) -> MockConnection {
            MockConnection(TokioIo::new(builder.build()))
        }
    }

    impl Read for MockConnection {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: ReadBufCursor<'_>,
        ) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
        }
    }

    impl Write for MockConnection {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
        }
    }

    impl Connection for MockConnection {
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }

    /// Converts a built HTTP proxy configuration into the runtime `ProxyConfig`.
    fn http_proxy(settings: config::HttpProxyConfig) -> ProxyConfig {
        ProxyConfig::from_config(&config::ProxyConfig::Http(settings)).unwrap()
    }

    /// Without a proxy the target is dialed directly and not marked as proxied.
    #[tokio::test]
    async fn unproxied() {
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, "http://foobar.com");
            Ok::<_, Box<dyn error::Error + Sync + Send>>(MockConnection::scripted(
                &mut tokio_test::io::Builder::new(),
            ))
        });
        let service = ProxyConnectorLayer::new(&ProxyConfig::Direct).layer(connector);

        let target: Uri = "http://foobar.com".parse().unwrap();
        let conn = service.oneshot(target).await.unwrap();

        assert!(!conn.connected().is_proxied())
    }

    /// An `http` target behind a proxy dials the proxy and is marked as proxied.
    #[tokio::test]
    async fn http_proxied_http() {
        let proxy = http_proxy(
            config::HttpProxyConfig::builder()
                .host_and_port(HostAndPort::new("127.0.0.1", 1234))
                .build(),
        );
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, PROXY_URI);
            Ok::<_, Box<dyn error::Error + Sync + Send>>(MockConnection::scripted(
                &mut tokio_test::io::Builder::new(),
            ))
        });
        let service = ProxyConnectorLayer::new(&proxy).layer(connector);

        let target: Uri = "http://foobar.com".parse().unwrap();
        let conn = service.oneshot(target).await.unwrap();

        assert!(conn.connected().is_proxied())
    }

    /// An `https` target behind a proxy is tunneled with an authenticated `CONNECT`.
    #[tokio::test]
    async fn http_proxied_https() {
        let proxy = http_proxy(
            config::HttpProxyConfig::builder()
                .host_and_port(HostAndPort::new("127.0.0.1", 1234))
                .credentials(Some(BasicCredentials::new("admin", "hunter2")))
                .build(),
        );
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, PROXY_URI);
            let mut script = tokio_test::io::Builder::new();
            script.write(
                b"CONNECT foobar.com:443 HTTP/1.1\r\n\
                  host: foobar.com:443\r\n\
                  proxy-authorization: Basic YWRtaW46aHVudGVyMg==\r\n\r\n",
            );
            script.read(b"HTTP/1.1 200 OK\r\n\r\n");
            Ok::<_, Box<dyn error::Error + Sync + Send>>(MockConnection::scripted(&mut script))
        });
        let service = ProxyConnectorLayer::new(&proxy).layer(connector);

        let target: Uri = "https://admin:hunter2@foobar.com/fizzbuzz".parse().unwrap();
        let conn = service.oneshot(target).await.unwrap();

        assert!(!conn.connected().is_proxied())
    }

    /// A non-2xx `CONNECT` response surfaces as a `ProxyTunnelError` carrying the status.
    #[tokio::test]
    async fn http_proxied_https_error() {
        let proxy = http_proxy(
            config::HttpProxyConfig::builder()
                .host_and_port(HostAndPort::new("127.0.0.1", 1234))
                .build(),
        );
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, PROXY_URI);
            let mut script = tokio_test::io::Builder::new();
            script.write(
                b"CONNECT foobar.com:443 HTTP/1.1\r\n\
                  host: foobar.com:443\r\n\r\n",
            );
            script.read(b"HTTP/1.1 401 Unauthorized\r\n\r\n");
            Ok::<_, Box<dyn error::Error + Sync + Send>>(MockConnection::scripted(&mut script))
        });
        let service = ProxyConnectorLayer::new(&proxy).layer(connector);

        let target: Uri = "https://admin:hunter2@foobar.com/fizzbuzz".parse().unwrap();
        let outcome = service.oneshot(target).await;
        let err = outcome
            .err()
            .unwrap()
            .downcast::<ProxyTunnelError>()
            .unwrap();

        assert_eq!(err.status, StatusCode::UNAUTHORIZED);
    }
}