use std::convert::TryInto;
use std::net::IpAddr;
use std::str::FromStr;
use std::time::Duration;
use anyhow::Result;
use anyhow::anyhow;
use log::*;
use http::{header::HeaderName, StatusCode};
use hyper::body::HttpBody;
use hyper::client::connect::Connect;
use hyper::header::{HeaderMap, HeaderValue};
use hyper::{client::HttpConnector, header, Body, Client, Request, Response, Uri};
/// A 60-second timeout exported by this module, presumably for callers to use
/// when configuring proxy-side timeouts (e.g. on their client's connector).
pub const PROXY_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Hop-by-hop header names that are always stripped before a message is forwarded.
///
/// `Keep-Alive` is also hop-by-hop; it is matched by name in `is_hop_header`
/// (see `KEEP_ALIVE`) rather than listed here.
const HOP_HEADERS: &[HeaderName] = &[
    header::CONNECTION,
    header::PROXY_AUTHENTICATE,
    header::PROXY_AUTHORIZATION,
    header::TE,
    header::TRAILER,
    header::TRANSFER_ENCODING,
    header::UPGRADE,
];

/// Lower-case name of the `Keep-Alive` header, which RFC 2616 §13.5.1 lists as
/// hop-by-hop. `HeaderName::as_str` is always lower-case, so a plain string
/// comparison is sufficient.
const KEEP_ALIVE: &str = "keep-alive";

/// Returns `true` if `name` is a hop-by-hop header that must not be forwarded
/// to the next hop: any entry of `HOP_HEADERS`, or `Keep-Alive`.
fn is_hop_header(name: &HeaderName) -> bool {
    name.as_str() == KEEP_ALIVE || HOP_HEADERS.contains(name)
}
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
let mut result = HeaderMap::new();
for (k, v) in headers.iter() {
if !is_hop_header(k) {
result.append(k.clone(), v.clone());
}
}
result
}
fn copy_upgrade_headers(
old_headers: &HeaderMap<HeaderValue>,
new_headers: &mut HeaderMap<HeaderValue>,
) -> Result<bool> {
let mut is_upgrade = false;
if let Some(conn) = old_headers.get(header::CONNECTION) {
let conn_str = conn.to_str()?.to_lowercase();
if conn_str.split(',').map(str::trim).any(|x| x == "upgrade") {
if let Some(upgrade) = old_headers.get(header::UPGRADE) {
new_headers.insert(header::CONNECTION, "Upgrade".try_into()?);
new_headers.insert(header::UPGRADE, upgrade.clone());
is_upgrade = true;
}
}
}
Ok(is_upgrade)
}
/// Builds the upstream URI by appending the incoming request's path, and its
/// query string if it has one, to `forward_url`.
///
/// `forward_url` is concatenated verbatim. For example, a trailing `/` on it
/// followed by the request path `/x` yields `//x`.
///
/// # Errors
/// Fails if the concatenated string does not parse as a `Uri`.
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
    let incoming = req.uri();
    let mut target = String::from(forward_url);
    target.push_str(incoming.path());
    if let Some(query) = incoming.query() {
        target.push('?');
        target.push_str(query);
    }
    Ok(Uri::from_str(&target)?)
}
fn create_proxied_request<B: std::default::Default>(
client_ip: IpAddr,
forward_url: &str,
request: Request<B>,
) -> Result<(Request<B>, Option<Request<B>>)> {
let mut builder = Request::builder()
.method(request.method())
.uri(forward_uri(forward_url, &request)?)
.version(hyper::Version::HTTP_11);
let old_headers = request.headers();
let new_headers = builder.headers_mut().unwrap();
*new_headers = remove_hop_headers(old_headers);
if let header::Entry::Vacant(entry) = new_headers.entry(header::HOST) {
if let Some(authority) = request.uri().authority() {
entry.insert(authority.as_str().parse()?);
}
}
let mut cookie_concat = vec![];
for cookie in new_headers.get_all(header::COOKIE) {
if !cookie_concat.is_empty() {
cookie_concat.extend(b"; ");
}
cookie_concat.extend_from_slice(cookie.as_bytes());
}
if !cookie_concat.is_empty() {
new_headers.insert(header::COOKIE, cookie_concat.try_into()?);
}
let x_forwarded_for_header_name = "x-forwarded-for";
match new_headers.entry(x_forwarded_for_header_name) {
header::Entry::Vacant(entry) => {
entry.insert(client_ip.to_string().parse()?);
}
header::Entry::Occupied(mut entry) => {
let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
entry.insert(addr.parse()?);
}
}
new_headers.insert(
HeaderName::from_bytes(b"x-forwarded-proto")?,
"https".try_into()?,
);
let is_upgrade = copy_upgrade_headers(old_headers, new_headers)?;
if is_upgrade {
Ok((builder.body(B::default())?, Some(request)))
} else {
Ok((builder.body(request.into_body())?, None))
}
}
/// Converts an upstream `response` into the response returned to the client.
///
/// Hop-by-hop headers are stripped and any upgrade handshake headers are
/// re-added. For a `101 Switching Protocols` response, the upstream connection
/// is upgraded here. A spawned tokio task then upgrades the client connection
/// (taken from `upgrade_request`) and relays bytes in both directions. The
/// client receives a fresh 101 response with an empty body and the filtered
/// headers. Any other status is passed through with only its headers replaced.
///
/// # Errors
/// * header conversion failures from `copy_upgrade_headers`;
/// * a 101 response when `upgrade_request` is `None`;
/// * failure to upgrade the upstream connection.
///
/// Client-side upgrade and relay failures happen after this returns and are
/// only logged.
async fn create_proxied_response<B: std::default::Default + Send + 'static>(
    mut response: Response<Body>,
    upgrade_request: Option<Request<B>>,
) -> Result<Response<Body>> {
    let mut filtered = remove_hop_headers(response.headers());
    copy_upgrade_headers(response.headers(), &mut filtered)?;

    // Ordinary responses keep their status, version and body; only headers change.
    if response.status() != StatusCode::SWITCHING_PROTOCOLS {
        *response.headers_mut() = filtered;
        return Ok(response);
    }

    let mut client_request = match upgrade_request {
        Some(req) => req,
        None => return Err(anyhow!("Switching protocols but not an upgrade request")),
    };
    let mut upstream_io = hyper::upgrade::on(response).await?;

    // hyper completes the client-side upgrade only after the 101 response built
    // below has been sent, so the relay runs in its own task rather than being
    // awaited here.
    tokio::spawn(async move {
        let mut client_io = match hyper::upgrade::on(&mut client_request).await {
            Ok(io) => io,
            Err(e) => {
                warn!(
                    "Could not upgrade client request when switching protocols: {}",
                    e
                );
                return;
            }
        };
        if let Err(e) = tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await {
            warn!("Error copying data in upgraded request: {}", e);
        }
    });

    let mut switching = Response::builder().status(StatusCode::SWITCHING_PROTOCOLS);
    *switching.headers_mut().unwrap() = filtered;
    Ok(switching.body(Default::default())?)
}
/// Proxies `request` to `forward_uri` through `client` and returns the
/// response to send back to the client.
///
/// `client_ip` is appended to the upstream `X-Forwarded-For` header. For
/// protocol-upgrade requests that the upstream accepts with `101`, the two
/// connections are relayed by a background task.
///
/// Connection establishment and timeouts are governed entirely by the
/// connector inside `client`; this function applies none of its own. To bound
/// connect time, configure the caller's connector, e.g.
/// `HttpConnector::set_connect_timeout(Some(PROXY_TIMEOUT))`, before building
/// the `Client`.
///
/// # Errors
/// Propagates failures in building the proxied request, in sending it through
/// `client`, and in converting the upstream response.
pub async fn call<C, B>(
    client_ip: IpAddr,
    forward_uri: &str,
    request: Request<B>,
    client: &Client<C, B>,
) -> Result<Response<Body>>
where
    C: Connect + Clone + Send + Sync + 'static,
    B: HttpBody + Send + std::default::Default + std::fmt::Debug + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let (proxied_request, upgrade_request) =
        create_proxied_request(client_ip, forward_uri, request)?;
    trace!("Proxied request: {:?}", proxied_request);

    // The previous code built and configured a private `HttpConnector` here that
    // was never used, so `PROXY_TIMEOUT` never took effect. It was removed; the
    // caller's `client` is what actually performs the request.
    let response = client.request(proxied_request).await?;
    trace!("Inner response: {:?}", response);

    create_proxied_response(response, upgrade_request).await
}