use std::{
marker::{PhantomData, Unpin},
task::{self, Poll},
};
use bytes::BytesMut;
use http::{HeaderMap, HeaderValue, Uri};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tower::{BoxError, Service};
use super::Tunneling;
use crate::ext::UriExt;
#[derive(Debug)]
pub struct TunnelConnector<C> {
headers: Headers,
inner: C,
proxy_dst: Uri,
}
#[derive(Clone, Debug)]
enum Headers {
Empty,
Auth(HeaderValue),
Extra(HeaderMap),
}
#[derive(Debug)]
pub enum TunnelError {
ConnectFailed(BoxError),
Io(std::io::Error),
Parse(httparse::Error),
MissingHost,
ProxyAuthRequired,
TunnelUnexpectedEof,
TunnelUnsuccessful,
}
impl<C> TunnelConnector<C> {
pub fn new(proxy_dst: Uri, connector: C) -> Self {
Self {
headers: Headers::Empty,
inner: connector,
proxy_dst,
}
}
pub fn with_auth(mut self, mut auth: HeaderValue) -> Self {
auth.set_sensitive(true);
match self.headers {
Headers::Empty => {
self.headers = Headers::Auth(auth);
}
Headers::Auth(ref mut existing) => {
*existing = auth;
}
Headers::Extra(ref mut extra) => {
extra.insert(http::header::PROXY_AUTHORIZATION, auth);
}
}
self
}
pub fn with_headers(mut self, mut headers: HeaderMap) -> Self {
match self.headers {
Headers::Empty => {
self.headers = Headers::Extra(headers);
}
Headers::Auth(auth) => {
headers
.entry(http::header::PROXY_AUTHORIZATION)
.or_insert(auth);
self.headers = Headers::Extra(headers);
}
Headers::Extra(ref mut extra) => {
extra.extend(headers);
}
}
self
}
}
impl<C> Service<Uri> for TunnelConnector<C>
where
C: Service<Uri>,
C::Future: Send + 'static,
C::Response: AsyncRead + AsyncWrite + Unpin + Send + 'static,
C::Error: Into<BoxError>,
{
type Response = C::Response;
type Error = TunnelError;
type Future = Tunneling<C::Future, C::Response, Self::Error>;
#[inline]
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner
.poll_ready(cx)
.map_err(Into::into)
.map_err(TunnelError::ConnectFailed)
}
fn call(&mut self, dst: Uri) -> Self::Future {
let connecting = self.inner.call(self.proxy_dst.clone());
let headers = self.headers.clone();
Tunneling {
fut: Box::pin(async move {
let conn = connecting
.await
.map_err(Into::into)
.map_err(TunnelError::ConnectFailed)?;
tunnel(
conn,
dst.host().ok_or(TunnelError::MissingHost)?,
dst.port_or_default(),
&headers,
)
.await
}),
_marker: PhantomData,
}
}
}
async fn tunnel<T>(mut conn: T, host: &str, port: u16, headers: &Headers) -> Result<T, TunnelError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = format!(
"\
CONNECT {host}:{port} HTTP/1.1\r\n\
Host: {host}:{port}\r\n\
"
)
.into_bytes();
match headers {
Headers::Auth(auth) => {
buf.extend_from_slice(b"Proxy-Authorization: ");
buf.extend_from_slice(auth.as_bytes());
buf.extend_from_slice(b"\r\n");
}
Headers::Extra(extra) => {
for (name, value) in extra {
buf.extend_from_slice(name.as_str().as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
}
Headers::Empty => (),
}
buf.extend_from_slice(b"\r\n");
conn.write_all(&buf).await.map_err(TunnelError::Io)?;
conn.flush().await.map_err(TunnelError::Io)?;
let mut buf = BytesMut::with_capacity(8192);
loop {
if conn.read_buf(&mut buf).await.map_err(TunnelError::Io)? == 0 {
return Err(TunnelError::TunnelUnexpectedEof);
}
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut res = httparse::Response::new(&mut headers);
match res.parse(&buf).map_err(TunnelError::Parse)? {
httparse::Status::Partial => continue,
httparse::Status::Complete(_) => match res.code {
Some(200) => return Ok(conn),
Some(407) => return Err(TunnelError::ProxyAuthRequired),
Some(_) | None => return Err(TunnelError::TunnelUnsuccessful),
},
}
}
}
impl std::fmt::Display for TunnelError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("tunnel error: ")?;
f.write_str(match self {
TunnelError::MissingHost => "missing destination host",
TunnelError::ProxyAuthRequired => "proxy authorization required",
TunnelError::Parse(_) => "invalid proxy response",
TunnelError::TunnelUnexpectedEof => "unexpected end of file",
TunnelError::TunnelUnsuccessful => "unsuccessful",
TunnelError::ConnectFailed(_) => "failed to create underlying connection",
TunnelError::Io(_) => "io error establishing tunnel",
})
}
}
impl std::error::Error for TunnelError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
TunnelError::Io(e) => Some(e),
TunnelError::Parse(e) => Some(e),
TunnelError::ConnectFailed(e) => Some(&**e),
_ => None,
}
}
}