conjure_runtime_raw/service/proxy/
connector.rs

1// Copyright 2020 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::service::proxy::{HttpProxyConfig, ProxyConfig};
15use bytes::Bytes;
16use futures::future::{self, BoxFuture};
17use http::header::{HOST, PROXY_AUTHORIZATION};
18use http::uri::Scheme;
19use http::{HeaderValue, Method, Request, StatusCode, Uri, Version};
20use http_body_util::Empty;
21use hyper::client::conn;
22use hyper::rt::{Read, ReadBufCursor, Write};
23use hyper_util::client::legacy::connect::{Connected, Connection};
24use std::convert::TryFrom;
25use std::error;
26use std::fmt;
27use std::future::Future;
28use std::io;
29use std::pin::{pin, Pin};
30use std::task::{Context, Poll};
31use tower_layer::Layer;
32use tower_service::Service;
33
34/// A connector layer which handles socket-level setup for HTTP proxies.
35///
36/// For http requests, we just connect to the proxy server and tell hyper that the connection is proxied so it can
37/// adjust the HTTP request. For https requests, we create a tunnel through the proxy server to the target server via a
38/// CONNECT request that is then used upstream for the TLS handshake.
39pub struct ProxyConnectorLayer {
40    config: Option<HttpProxyConfig>,
41}
42
43impl ProxyConnectorLayer {
44    pub fn new(config: &ProxyConfig) -> ProxyConnectorLayer {
45        let config = match config {
46            ProxyConfig::Http(config) => Some(config.clone()),
47            _ => None,
48        };
49
50        ProxyConnectorLayer { config }
51    }
52}
53
54impl<S> Layer<S> for ProxyConnectorLayer {
55    type Service = ProxyConnectorService<S>;
56
57    fn layer(&self, inner: S) -> Self::Service {
58        ProxyConnectorService {
59            inner,
60            config: self.config.clone(),
61        }
62    }
63}
64
65#[derive(Clone)]
66pub struct ProxyConnectorService<S> {
67    inner: S,
68    config: Option<HttpProxyConfig>,
69}
70
71impl<S> Service<Uri> for ProxyConnectorService<S>
72where
73    S: Service<Uri> + Send,
74    S::Response: Read + Write + Unpin + Send + 'static,
75    S::Error: Into<Box<dyn error::Error + Sync + Send>>,
76    S::Future: Send + 'static,
77{
78    type Response = ProxyConnection<S::Response>;
79    type Error = Box<dyn error::Error + Sync + Send>;
80    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
81
82    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83        self.inner.poll_ready(cx).map_err(Into::into)
84    }
85
86    fn call(&mut self, req: Uri) -> Self::Future {
87        let config = match &self.config {
88            Some(proxy) => proxy.clone(),
89            None => {
90                let connect = self.inner.call(req);
91                return Box::pin(async move {
92                    let stream = connect.await.map_err(Into::into)?;
93                    Ok(ProxyConnection {
94                        stream,
95                        proxy: false,
96                    })
97                });
98            }
99        };
100
101        let connect = self.inner.call(config.uri.clone());
102        Box::pin(async move {
103            let stream = connect.await.map_err(Into::into)?;
104
105            if req.scheme() == Some(&Scheme::HTTP) {
106                Ok(ProxyConnection {
107                    stream,
108                    proxy: true,
109                })
110            } else {
111                let stream = connect_https(stream, req, config).await?;
112                Ok(ProxyConnection {
113                    stream,
114                    proxy: false,
115                })
116            }
117        })
118    }
119}
120
121async fn connect_https<T>(
122    stream: T,
123    uri: Uri,
124    config: HttpProxyConfig,
125) -> Result<T, Box<dyn error::Error + Sync + Send>>
126where
127    T: Read + Write + Send + Unpin + 'static,
128{
129    let (mut sender, mut conn) = conn::http1::handshake(stream).await?;
130
131    let host = uri.host().ok_or("host missing from URI")?;
132    let authority = format!("{}:{}", host, uri.port_u16().unwrap_or(443));
133    let authority_uri = Uri::try_from(authority.clone()).unwrap();
134    let host = HeaderValue::try_from(authority).unwrap();
135
136    let mut request = Request::new(Empty::<Bytes>::new());
137    *request.method_mut() = Method::CONNECT;
138    *request.uri_mut() = authority_uri;
139    *request.version_mut() = Version::HTTP_11;
140    request.headers_mut().insert(HOST, host);
141    if let Some(credentials) = config.credentials {
142        request
143            .headers_mut()
144            .insert(PROXY_AUTHORIZATION, credentials);
145    }
146
147    let mut response = pin!(sender.send_request(request));
148    let response = future::poll_fn(|cx| {
149        let _ = conn.poll_without_shutdown(cx)?;
150        response.as_mut().poll(cx)
151    })
152    .await?;
153
154    if !response.status().is_success() {
155        return Err(ProxyTunnelError {
156            status: response.status(),
157        }
158        .into());
159    }
160
161    Ok(conn.into_parts().io)
162}
163
/// A transport produced by [`ProxyConnectorService`].
///
/// All I/O is delegated unchanged to the wrapped stream. The only addition is the `proxy` flag, which is reported to
/// hyper through [`Connection::connected`].
#[derive(Debug)]
pub struct ProxyConnection<T> {
    /// The underlying transport.
    stream: T,
    /// Whether this connection goes to the proxy itself rather than end-to-end to the target.
    proxy: bool,
}
169
170impl<T> Read for ProxyConnection<T>
171where
172    T: Read + Unpin,
173{
174    fn poll_read(
175        mut self: Pin<&mut Self>,
176        cx: &mut Context<'_>,
177        buf: ReadBufCursor<'_>,
178    ) -> Poll<io::Result<()>> {
179        Pin::new(&mut self.stream).poll_read(cx, buf)
180    }
181}
182
183impl<T> Write for ProxyConnection<T>
184where
185    T: Write + Unpin,
186{
187    fn poll_write(
188        mut self: Pin<&mut Self>,
189        cx: &mut Context<'_>,
190        buf: &[u8],
191    ) -> Poll<io::Result<usize>> {
192        Pin::new(&mut self.stream).poll_write(cx, buf)
193    }
194
195    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
196        Pin::new(&mut self.stream).poll_flush(cx)
197    }
198
199    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
200        Pin::new(&mut self.stream).poll_shutdown(cx)
201    }
202}
203
204impl<T> Connection for ProxyConnection<T>
205where
206    T: Connection,
207{
208    fn connected(&self) -> Connected {
209        self.stream.connected().proxy(self.proxy)
210    }
211}
212
213#[derive(Debug)]
214struct ProxyTunnelError {
215    status: StatusCode,
216}
217
218impl fmt::Display for ProxyTunnelError {
219    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
220        write!(fmt, "got status {} from HTTPS proxy", self.status)
221    }
222}
223
224impl error::Error for ProxyTunnelError {}
225
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::{self, BasicCredentials, HostAndPort};
    use hyper_util::rt::TokioIo;
    use tower_util::ServiceExt;

    /// Error type returned by the fake inner connectors below.
    type BoxError = Box<dyn error::Error + Sync + Send>;

    /// Scripted transport handed out by the fake inner connectors.
    struct MockConnection(TokioIo<tokio_test::io::Mock>);

    impl MockConnection {
        /// Finalizes `script` into a mock stream and wraps it.
        fn new(script: &mut tokio_test::io::Builder) -> Self {
            MockConnection(TokioIo::new(script.build()))
        }
    }

    impl Read for MockConnection {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: ReadBufCursor<'_>,
        ) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
        }
    }

    impl Write for MockConnection {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
        }
    }

    impl Connection for MockConnection {
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }

    /// Converts an HTTP proxy config into the runtime `ProxyConfig` consumed by the layer.
    fn http_proxy(proxy: config::HttpProxyConfig) -> ProxyConfig {
        ProxyConfig::from_config(&config::ProxyConfig::Http(proxy)).unwrap()
    }

    #[tokio::test]
    async fn unproxied() {
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, "http://foobar.com");
            Ok::<_, BoxError>(MockConnection::new(&mut tokio_test::io::Builder::new()))
        });
        let service = ProxyConnectorLayer::new(&ProxyConfig::Direct).layer(connector);

        let conn = service
            .oneshot("http://foobar.com".parse().unwrap())
            .await
            .unwrap();

        assert!(!conn.connected().is_proxied());
    }

    #[tokio::test]
    async fn http_proxied_http() {
        let proxy = http_proxy(
            config::HttpProxyConfig::builder()
                .host_and_port(HostAndPort::new("127.0.0.1", 1234))
                .build(),
        );
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, "http://127.0.0.1:1234");
            Ok::<_, BoxError>(MockConnection::new(&mut tokio_test::io::Builder::new()))
        });
        let service = ProxyConnectorLayer::new(&proxy).layer(connector);

        let conn = service
            .oneshot("http://foobar.com".parse().unwrap())
            .await
            .unwrap();

        assert!(conn.connected().is_proxied());
    }

    #[tokio::test]
    async fn http_proxied_https() {
        let proxy = http_proxy(
            config::HttpProxyConfig::builder()
                .host_and_port(HostAndPort::new("127.0.0.1", 1234))
                .credentials(Some(BasicCredentials::new("admin", "hunter2")))
                .build(),
        );
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, "http://127.0.0.1:1234");
            let mut script = tokio_test::io::Builder::new();
            script
                .write(
                    b"CONNECT foobar.com:443 HTTP/1.1\r\n\
                    host: foobar.com:443\r\n\
                    proxy-authorization: Basic YWRtaW46aHVudGVyMg==\r\n\r\n",
                )
                .read(b"HTTP/1.1 200 OK\r\n\r\n");
            Ok::<_, BoxError>(MockConnection::new(&mut script))
        });
        let service = ProxyConnectorLayer::new(&proxy).layer(connector);

        let conn = service
            .oneshot("https://admin:hunter2@foobar.com/fizzbuzz".parse().unwrap())
            .await
            .unwrap();

        assert!(!conn.connected().is_proxied());
    }

    #[tokio::test]
    async fn http_proxied_https_error() {
        let proxy = http_proxy(
            config::HttpProxyConfig::builder()
                .host_and_port(HostAndPort::new("127.0.0.1", 1234))
                .build(),
        );
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, "http://127.0.0.1:1234");
            let mut script = tokio_test::io::Builder::new();
            script
                .write(
                    b"CONNECT foobar.com:443 HTTP/1.1\r\n\
                    host: foobar.com:443\r\n\r\n",
                )
                .read(b"HTTP/1.1 401 Unauthorized\r\n\r\n");
            Ok::<_, BoxError>(MockConnection::new(&mut script))
        });
        let service = ProxyConnectorLayer::new(&proxy).layer(connector);

        let err = service
            .oneshot("https://admin:hunter2@foobar.com/fizzbuzz".parse().unwrap())
            .await
            .err()
            .unwrap()
            .downcast::<ProxyTunnelError>()
            .unwrap();

        assert_eq!(err.status, StatusCode::UNAUTHORIZED);
    }
}