use crate::service::proxy::hyper::{HttpProxyConfig, ProxyConfig};
use crate::{builder, Builder};
use bytes::Bytes;
use conjure_error::Error;
use futures::future::{self, BoxFuture};
use http::header::{HOST, PROXY_AUTHORIZATION};
use http::uri::Scheme;
use http::{HeaderValue, Method, Request, StatusCode, Uri, Version};
use http_body_util::Empty;
use hyper::client::conn;
use hyper::rt::{Read, ReadBufCursor, Write};
use hyper_util::client::legacy::connect::{Connected, Connection};
use std::convert::TryFrom;
use std::error;
use std::fmt;
use std::future::Future;
use std::io;
use std::pin::{pin, Pin};
use std::task::{Context, Poll};
use tower_layer::Layer;
use tower_service::Service;
/// Tower layer that wraps a connector in a [`ProxyConnectorService`].
///
/// The layer only carries the optional HTTP proxy configuration; a copy of it is handed to every
/// service the layer produces.
pub struct ProxyConnectorLayer {
// `None` means no HTTP proxy is in use and the wrapped connector is called with the target URI.
config: Option<HttpProxyConfig>,
}
impl ProxyConnectorLayer {
    /// Creates a layer from the proxy settings of a completed client builder.
    ///
    /// Only the `ProxyConfig::Http` variant is retained; any other resolved configuration leaves
    /// the layer without a proxy.
    ///
    /// # Errors
    ///
    /// Returns whatever error `ProxyConfig::from_config` reports for the builder's proxy
    /// settings.
    pub fn new(builder: &Builder<builder::Complete>) -> Result<ProxyConnectorLayer, Error> {
        let resolved = ProxyConfig::from_config(builder.get_proxy())?;

        let config = if let ProxyConfig::Http(http) = resolved {
            Some(http.clone())
        } else {
            None
        };

        Ok(ProxyConnectorLayer { config })
    }
}
impl<S> Layer<S> for ProxyConnectorLayer {
    type Service = ProxyConnectorService<S>;

    /// Wraps `inner` in a [`ProxyConnectorService`] that carries its own copy of the proxy
    /// configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let config = self.config.clone();
        ProxyConnectorService { inner, config }
    }
}
/// Connector service that optionally routes connections through an HTTP proxy.
///
/// Without a proxy configuration the target URI is passed straight to `inner`. With one, `inner`
/// is asked to connect to the proxy's URI instead of the target.
#[derive(Clone)]
pub struct ProxyConnectorService<S> {
// The wrapped connector that actually opens the stream.
inner: S,
// Proxy to connect through; `None` disables proxying.
config: Option<HttpProxyConfig>,
}
impl<S> Service<Uri> for ProxyConnectorService<S>
where
    S: Service<Uri> + Send,
    S::Response: Read + Write + Unpin + Send + 'static,
    S::Error: Into<Box<dyn error::Error + Sync + Send>>,
    S::Future: Send + 'static,
{
    type Response = ProxyConnection<S::Response>;
    type Error = Box<dyn error::Error + Sync + Send>;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Reports the readiness of the wrapped connector, boxing its error type.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.inner.poll_ready(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Opens a connection for `req`, going through the configured proxy when there is one.
    ///
    /// * No proxy: the wrapped connector is called with `req` itself.
    /// * Proxy and an `http` target: the stream to the proxy is returned and flagged as proxied.
    /// * Proxy and any other target: a `CONNECT` tunnel is established first (see
    ///   `connect_https`) and the resulting stream is not flagged as proxied.
    fn call(&mut self, req: Uri) -> Self::Future {
        // Guard clause: with no proxy configured, simply delegate to the inner connector.
        let Some(config) = self.config.clone() else {
            let direct = self.inner.call(req);
            return Box::pin(async move {
                let stream = direct.await.map_err(Into::into)?;
                Ok(ProxyConnection {
                    stream,
                    proxy: false,
                })
            });
        };

        // The inner connector dials the proxy, not the requested target.
        let to_proxy = self.inner.call(config.uri.clone());
        Box::pin(async move {
            let stream = to_proxy.await.map_err(Into::into)?;

            // Only plain-HTTP targets are marked as proxied; everything else gets a tunnel.
            let plain_http = req.scheme() == Some(&Scheme::HTTP);
            let stream = if plain_http {
                stream
            } else {
                connect_https(stream, req, config).await?
            };

            Ok(ProxyConnection {
                stream,
                proxy: plain_http,
            })
        })
    }
}
/// Establishes a `CONNECT` tunnel to `uri` over a stream that is already connected to the proxy.
///
/// An HTTP/1.1 `CONNECT host:port` request is sent over `stream`, with the port defaulting to
/// 443 when `uri` does not specify one. The same `host:port` value is used for the `Host`
/// header, and `config.credentials`, when present, is sent as `Proxy-Authorization`. When the
/// proxy answers with a successful status, the underlying stream is handed back for the caller
/// to use.
///
/// # Errors
///
/// * the HTTP/1 handshake or the request itself fails,
/// * `uri` has no host,
/// * the proxy responds with a non-success status (reported as `ProxyTunnelError`).
///
/// # Panics
///
/// Panics if the computed `host:port` string cannot be converted into a `Uri` or a
/// `HeaderValue`.
async fn connect_https<T>(
    stream: T,
    uri: Uri,
    config: HttpProxyConfig,
) -> Result<T, Box<dyn error::Error + Sync + Send>>
where
    T: Read + Write + Send + Unpin + 'static,
{
    let (mut sender, mut conn) = conn::http1::handshake(stream).await?;

    let target = match uri.host() {
        Some(host) => format!("{}:{}", host, uri.port_u16().unwrap_or(443)),
        None => return Err("host missing from URI".into()),
    };
    let connect_uri = Uri::try_from(target.clone()).unwrap();
    let host_header = HeaderValue::try_from(target).unwrap();

    // Assemble the CONNECT request from its parts.
    let (mut parts, body) = Request::new(Empty::<Bytes>::new()).into_parts();
    parts.method = Method::CONNECT;
    parts.uri = connect_uri;
    parts.version = Version::HTTP_11;
    parts.headers.insert(HOST, host_header);
    if let Some(credentials) = config.credentials {
        parts.headers.insert(PROXY_AUTHORIZATION, credentials);
    }
    let request = Request::from_parts(parts, body);

    // The connection is polled by hand (rather than spawned) so that the IO object can be
    // reclaimed from it once the proxy has answered.
    let mut pending = pin!(sender.send_request(request));
    let response = future::poll_fn(|cx| {
        let _ = conn.poll_without_shutdown(cx)?;
        pending.as_mut().poll(cx)
    })
    .await?;

    let status = response.status();
    if status.is_success() {
        Ok(conn.into_parts().io)
    } else {
        Err(ProxyTunnelError { status }.into())
    }
}
/// Stream returned by [`ProxyConnectorService`], wrapping the stream of the inner connector.
///
/// Reads and writes are delegated to the wrapped stream; the wrapper only adds the proxy flag
/// reported through `Connection::connected`.
#[derive(Debug)]
pub struct ProxyConnection<T> {
// The stream produced by the inner connector (or recovered after a CONNECT tunnel).
stream: T,
// Value passed to `Connected::proxy`: `true` only for plain-HTTP requests sent via the proxy.
proxy: bool,
}
impl<T> Read for ProxyConnection<T>
where
    T: Read + Unpin,
{
    /// Reads from the wrapped stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: ReadBufCursor<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.stream).poll_read(cx, buf)
    }
}
impl<T> Write for ProxyConnection<T>
where
    T: Write + Unpin,
{
    /// Writes `buf` to the wrapped stream.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.stream).poll_write(cx, buf)
    }

    /// Flushes the wrapped stream.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.stream).poll_flush(cx)
    }

    /// Shuts down the write side of the wrapped stream.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.stream).poll_shutdown(cx)
    }

    /// Reports whether the wrapped stream supports vectored writes.
    ///
    /// Forwarded explicitly: the trait default returns `false`, which would hide the inner
    /// stream's capability behind this wrapper.
    fn is_write_vectored(&self) -> bool {
        self.stream.is_write_vectored()
    }

    /// Performs a vectored write on the wrapped stream.
    ///
    /// Forwarded so the inner stream's own implementation is used instead of the trait default.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
    }
}
impl<T> Connection for ProxyConnection<T>
where
    T: Connection,
{
    /// Returns the wrapped stream's connection metadata with the proxy flag applied to it.
    fn connected(&self) -> Connected {
        let info = self.stream.connected();
        info.proxy(self.proxy)
    }
}
/// Error returned when the proxy answers a `CONNECT` request with a non-success status.
#[derive(Debug)]
struct ProxyTunnelError {
// Status code of the proxy's response.
status: StatusCode,
}
impl fmt::Display for ProxyTunnelError {
    /// Renders the error as a short message containing the status the proxy returned.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_fmt(format_args!(
            "got status {} from HTTPS proxy",
            self.status
        ))
    }
}
// Empty impl: no `Error` methods are overridden, so the trait's defaults apply.
impl error::Error for ProxyTunnelError {}
#[cfg(test)]
mod test {
    use super::*;
    use hyper_util::rt::TokioIo;
    use tower_util::ServiceExt;

    /// Error type returned by the mock connectors.
    type BoxError = Box<dyn error::Error + Sync + Send>;

    /// URI of the fake proxy used by the proxied tests.
    const PROXY_URI: &str = "http://127.0.0.1:1234";

    /// HTTPS target requested by the tunnel tests.
    const HTTPS_TARGET: &str = "https://admin:hunter2@foobar.com/fizzbuzz";

    /// Scripted in-memory stream standing in for a real connection.
    struct MockConnection(TokioIo<tokio_test::io::Mock>);

    impl MockConnection {
        /// Builds the mock stream described by `script` and wraps it as a connector result.
        fn scripted(mut script: tokio_test::io::Builder) -> Result<MockConnection, BoxError> {
            Ok(MockConnection(TokioIo::new(script.build())))
        }
    }

    impl Read for MockConnection {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: ReadBufCursor<'_>,
        ) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
        }
    }

    impl Write for MockConnection {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
        }
    }

    impl Connection for MockConnection {
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }

    /// Layer configured to use the fake proxy with the given credentials.
    fn proxied_layer(credentials: Option<HeaderValue>) -> ProxyConnectorLayer {
        ProxyConnectorLayer {
            config: Some(HttpProxyConfig {
                uri: PROXY_URI.parse().unwrap(),
                credentials,
            }),
        }
    }

    /// Without a proxy the inner connector sees the target URI and the result is not proxied.
    #[tokio::test]
    async fn unproxied() {
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, "http://foobar.com");
            MockConnection::scripted(tokio_test::io::Builder::new())
        });
        let service = ProxyConnectorLayer { config: None }.layer(connector);

        let target: Uri = "http://foobar.com".parse().unwrap();
        let conn = service.oneshot(target).await.unwrap();

        assert!(!conn.connected().is_proxied());
    }

    /// A plain-HTTP target is sent to the proxy and the connection is flagged as proxied.
    #[tokio::test]
    async fn http_proxied_http() {
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, PROXY_URI);
            MockConnection::scripted(tokio_test::io::Builder::new())
        });
        let service = proxied_layer(None).layer(connector);

        let target: Uri = "http://foobar.com".parse().unwrap();
        let conn = service.oneshot(target).await.unwrap();

        assert!(conn.connected().is_proxied());
    }

    /// An HTTPS target triggers a CONNECT request carrying the configured credentials.
    #[tokio::test]
    async fn http_proxied_https() {
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, PROXY_URI);
            let mut script = tokio_test::io::Builder::new();
            script.write(
                b"CONNECT foobar.com:443 HTTP/1.1\r\n\
                host: foobar.com:443\r\n\
                proxy-authorization: Basic YWRtaW46aHVudGVyMg==\r\n\r\n",
            );
            script.read(b"HTTP/1.1 200 OK\r\n\r\n");
            MockConnection::scripted(script)
        });
        let credentials = Some("Basic YWRtaW46aHVudGVyMg==".parse().unwrap());
        let service = proxied_layer(credentials).layer(connector);

        let target: Uri = HTTPS_TARGET.parse().unwrap();
        let conn = service.oneshot(target).await.unwrap();

        assert!(!conn.connected().is_proxied());
    }

    /// A non-success CONNECT response surfaces as a `ProxyTunnelError` with the proxy's status.
    #[tokio::test]
    async fn http_proxied_https_error() {
        let connector = tower_util::service_fn(|uri: Uri| async move {
            assert_eq!(uri, PROXY_URI);
            let mut script = tokio_test::io::Builder::new();
            script.write(
                b"CONNECT foobar.com:443 HTTP/1.1\r\n\
                host: foobar.com:443\r\n\r\n",
            );
            script.read(b"HTTP/1.1 401 Unauthorized\r\n\r\n");
            MockConnection::scripted(script)
        });
        let service = proxied_layer(None).layer(connector);

        let target: Uri = HTTPS_TARGET.parse().unwrap();
        let failure = service.oneshot(target).await.err().unwrap();
        let tunnel_error = failure.downcast::<ProxyTunnelError>().unwrap();

        assert_eq!(tunnel_error.status, StatusCode::UNAUTHORIZED);
    }
}