use tracing::debug;
use {
http::{
HeaderMap, Response, StatusCode,
request::{Parts, Request},
},
hyper::{Body, client::conn::Parts as ConnParts, upgrade::OnUpgrade},
std::net::SocketAddr,
tokio::{io::copy_bidirectional, net::TcpStream, task::JoinHandle},
tracing::{error, info, warn},
};
/// Pairs a Pushpin backend name with optional replacement request details.
///
/// NOTE(review): this type is not referenced in this chunk; the notes below on how its
/// fields are consumed are assumptions to confirm against the callers.
#[derive(Debug)]
pub struct PushpinRedirectInfo {
/// Name of the Pushpin backend. Presumably passed to `proxy_through_pushpin` as
/// `backend_name`, which sends it as the `pushpin-route` header — confirm with callers.
pub backend_name: String,
/// Optional replacement request line/headers. Presumably passed to
/// `proxy_through_pushpin` as `redirect_request_info` — confirm with callers.
pub request_info: Option<PushpinRedirectRequestInfo>,
}
/// Owned snapshot of an HTTP request's method, URI components and headers.
///
/// `proxy_through_pushpin` reads `method` and `path_and_query` to build the outgoing
/// request line, and `headers` to choose the forwarded headers.
#[derive(Debug)]
pub struct PushpinRedirectRequestInfo {
/// Request method in its string form (e.g. `GET`).
pub method: String,
/// URI scheme, if the URI had one.
pub scheme: Option<String>,
/// URI authority (host and optional port), if the URI had one.
pub authority: Option<String>,
/// URI path plus query string, if present.
pub path_and_query: Option<String>,
/// Request headers; multi-valued headers keep all of their values.
pub headers: HeaderMap,
}
impl PushpinRedirectRequestInfo {
pub fn from_parts(parts: &Parts) -> Self {
PushpinRedirectRequestInfo {
method: parts.method.to_string(),
scheme: parts.uri.scheme().map(|x| x.to_string()),
authority: parts.uri.authority().map(|x| x.to_string()),
path_and_query: parts.uri.path_and_query().map(|p| p.to_string()),
headers: parts.headers.clone(),
}
}
}
/// Request header names that redirect info may not override.
///
/// When `proxy_through_pushpin` receives redirect info, headers whose names appear here
/// are copied from the original request, and every other header is copied from the
/// redirect info. Names are matched case-insensitively.
///
/// The list covers Host, the WebSocket/upgrade handshake headers, `pushpin-route`,
/// message-framing headers (content-length, content-range, transfer-encoding, te,
/// trailer, expect), proxy auth headers, and the `fastly-ff` / `cdn-loop` headers.
const PROTECTED_REQ_HEADERS: &[&str] = &[
"host",
"connection",
"sec-websocket-version",
"sec-websocket-key",
"upgrade",
"pushpin-route",
"content-length",
"content-range",
"expect",
"fastly-ff",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"cdn-loop",
];
/// Forwards a request to the Pushpin instance at `pushpin_addr` and returns Pushpin's
/// response.
///
/// * Request line: the method and path/query come from `redirect_request_info` when it
///   is `Some`, otherwise from `original_request_info`. A missing or empty path/query
///   is sent as `/`, because an HTTP/1.1 origin-form request target must not be empty.
/// * Headers: with redirect info, names listed in `PROTECTED_REQ_HEADERS` are copied
///   from the original request and every other header from the redirect info. Without
///   redirect info, all original headers are copied.
/// * `host` is always set to `pushpin_addr` and `pushpin-route` to `backend_name`.
///   Incoming values for these two headers are dropped first, so the outgoing request
///   carries exactly one of each. The request builder appends rather than replaces, so
///   copying them through would produce duplicate lines.
/// * The original request body is streamed through unchanged.
///
/// If Pushpin answers `101 Switching Protocols`, a background task joins the upgraded
/// downstream connection (`original_request_on_upgrade`) to the upstream Pushpin
/// connection.
///
/// Failures are logged and turned into error responses rather than returned:
/// `500` if connecting, building the request, or the HTTP handshake fails, and `502`
/// if sending the request to Pushpin fails.
pub async fn proxy_through_pushpin(
    pushpin_addr: SocketAddr,
    backend_name: String,
    redirect_request_info: Option<PushpinRedirectRequestInfo>,
    original_request_info: PushpinRedirectRequestInfo,
    original_request_body: Body,
    original_request_on_upgrade: OnUpgrade,
) -> Response<Body> {
    // Headers this function always sets itself; incoming copies are discarded.
    const PROXY_SET_HEADERS: [&str; 2] = ["host", "pushpin-route"];

    debug!("Proxying through Pushpin backend '{}'.", backend_name);
    let pushpin_stream = match TcpStream::connect(pushpin_addr).await {
        Ok(stream) => stream,
        Err(e) => {
            error!("Could not connect to Pushpin: {e}.");
            return build_error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Could not connect to Pushpin",
            );
        }
    };

    // The redirect, when present, fully determines the request line.
    let line_source = redirect_request_info
        .as_ref()
        .unwrap_or(&original_request_info);
    let path_and_query = line_source
        .path_and_query
        .as_deref()
        .filter(|p| !p.is_empty())
        .unwrap_or("/");
    let mut req = Request::builder()
        .method(line_source.method.as_str())
        .uri(path_and_query);

    let is_protected =
        |name: &str| PROTECTED_REQ_HEADERS.iter().any(|h| h.eq_ignore_ascii_case(name));
    let forwarded_headers: Vec<_> = match &redirect_request_info {
        // The redirect may change any header except the protected ones, which always
        // come from the original request.
        Some(redirect) => original_request_info
            .headers
            .iter()
            .filter(|(name, _)| is_protected(name.as_str()))
            .chain(
                redirect
                    .headers
                    .iter()
                    .filter(|(name, _)| !is_protected(name.as_str())),
            )
            .collect(),
        None => original_request_info.headers.iter().collect(),
    };
    for (name, value) in forwarded_headers {
        if !PROXY_SET_HEADERS
            .iter()
            .any(|h| h.eq_ignore_ascii_case(name.as_str()))
        {
            req = req.header(name, value);
        }
    }
    req = req.header("host", pushpin_addr.to_string());
    req = req.header("pushpin-route", backend_name.as_str());

    let req = match req.body(original_request_body) {
        Ok(req) => req,
        Err(e) => {
            error!("Failed to build Pushpin proxy request: {}", e);
            return build_error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Could not build proxy request",
            );
        }
    };

    let (mut sender, conn) = match hyper::client::conn::Builder::new()
        .handshake(pushpin_stream)
        .await
    {
        Ok(res) => res,
        Err(e) => {
            error!("Pushpin handshake failed: {}", e);
            return build_error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Could not connect to upstream service: {e}"),
            );
        }
    };
    // `without_shutdown` makes the connection task yield the raw IO (and any buffered
    // bytes) when it finishes, which the upgrade path below needs.
    let conn_fut = tokio::spawn(conn.without_shutdown());

    let upstream_resp = match sender.send_request(req).await {
        Ok(proxy_resp) => {
            info!(
                "Received response from Pushpin backend '{}'. Proxying response.",
                backend_name
            );
            proxy_resp
        }
        Err(e) => {
            error!("Pushpin proxy request failed: {}", e);
            return build_error_response(
                StatusCode::BAD_GATEWAY,
                format!("Pushpin request failed: {e}"),
            );
        }
    };

    if upstream_resp.status() == StatusCode::SWITCHING_PROTOCOLS {
        debug!("Pushpin responded with `101 Switching Protocols`, attempting upgrade...");
        tokio::spawn(proxy_upgraded_connection(
            original_request_on_upgrade,
            conn_fut,
        ));
    }
    upstream_resp
}
async fn proxy_upgraded_connection(
downstream_req_on_upgrade: OnUpgrade,
upstream_conn_fut: JoinHandle<Result<ConnParts<TcpStream>, hyper::Error>>,
) {
let mut downstream_upgraded = match downstream_req_on_upgrade.await {
Ok(upgraded) => upgraded,
Err(e) => {
error!("Downstream client upgrade failed: {}", e);
return;
}
};
debug!("Downstream client connection upgraded.");
let mut upstream_parts = match upstream_conn_fut.await {
Ok(Ok(parts)) => parts,
Ok(Err(e)) => {
error!("Upstream connection error: {}", e);
return;
}
Err(e) => {
warn!("Upstream connection task failed: {}", e);
return;
}
};
debug!("Upstream connection IO stream obtained.");
match copy_bidirectional(&mut downstream_upgraded, &mut upstream_parts.io).await {
Ok((from_client, from_server)) => {
info!(
"Upgraded proxy connection finished gracefully. Bytes transferred: client->server: {}, server->client: {}",
from_client, from_server
);
}
Err(e) => {
error!("Upgraded proxy I/O error: {e}");
}
}
}
/// Builds a plain-text error response with the given `status`.
///
/// The body is `Error: <message>`. No headers are set.
fn build_error_response(status: StatusCode, message: impl ToString) -> Response<Body> {
    let text = format!("Error: {}", message.to_string());
    let mut response = Response::new(Body::from(text));
    *response.status_mut() = status;
    response
}