mod internal;
pub mod builder;
use crate::{
Error, HttpHandler, WebSocketHandler, builder::ProxyBuilder,
certificate_authority::CertificateAuthority,
};
use builder::{AddrOrListener, WantsAddr};
use hyper::service::service_fn;
use hyper_util::{
client::legacy::{Builder as ClientBuilder, Client, connect::Connect},
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder as ServerBuilder,
};
use internal::InternalProxy;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_graceful::Shutdown;
use tokio_tungstenite::Connector;
use tracing::error;
/// A proxy server, created through [`Proxy::builder`] and run with [`Proxy::start`].
///
/// Type parameters (bounds as required by `start`):
/// - `C`: connector used to build the upstream HTTP client.
/// - `CA`: certificate authority, shared with every request handler.
/// - `H`: HTTP handler, cloned per connection and per request.
/// - `W`: WebSocket handler, cloned per connection and per request.
/// - `F`: future that triggers graceful shutdown when it resolves.
pub struct Proxy<C, CA, H, W, F> {
    /// Address to bind, or an already-bound listener to use as-is.
    al: AddrOrListener,
    /// Certificate authority handed to each `InternalProxy`
    /// (presumably used for intercepted TLS — confirm in `internal`).
    ca: Arc<CA>,
    /// Connector the upstream hyper client is built with.
    http_connector: C,
    /// Optional client builder; when `None`, `start` uses one with HTTP/1
    /// title-case headers and header-case preservation enabled.
    client: Option<ClientBuilder>,
    /// Handler for HTTP traffic.
    http_handler: H,
    /// Handler for WebSocket traffic.
    websocket_handler: W,
    /// Optional TLS connector passed through to `InternalProxy` for WebSocket connections.
    websocket_connector: Option<Connector>,
    /// Optional server builder; when `None`, `start` uses one with HTTP/1
    /// title-case headers and header-case preservation enabled.
    server: Option<ServerBuilder<TokioExecutor>>,
    /// Future whose completion makes `start` stop accepting and drain connections.
    graceful_shutdown: F,
}
// Implemented on the all-unit instantiation purely so callers can write
// `Proxy::builder()` without naming any type parameters.
impl Proxy<(), (), (), (), ()> {
    /// Returns a fresh [`ProxyBuilder`] in its initial `WantsAddr` state.
    pub fn builder() -> ProxyBuilder<WantsAddr> {
        ProxyBuilder::<WantsAddr>::new()
    }
}
impl<C, CA, H, W, F> Proxy<C, CA, H, W, F>
where
C: Connect + Clone + Send + Sync + 'static,
CA: CertificateAuthority,
H: HttpHandler,
W: WebSocketHandler,
F: Future<Output = ()> + Send + 'static,
{
pub async fn start(self) -> Result<(), Error> {
let client = self
.client
.unwrap_or_else(|| {
let mut builder = Client::builder(TokioExecutor::new());
builder
.http1_title_case_headers(true)
.http1_preserve_header_case(true);
builder
})
.build(self.http_connector);
let server = self.server.unwrap_or_else(|| {
let mut builder = ServerBuilder::new(TokioExecutor::new());
builder
.http1()
.title_case_headers(true)
.preserve_header_case(true);
builder
});
let listener = match self.al {
AddrOrListener::Addr(addr) => TcpListener::bind(addr).await?,
AddrOrListener::Listener(listener) => listener,
};
let shutdown = Shutdown::new(self.graceful_shutdown);
let guard = shutdown.guard_weak();
loop {
tokio::select! {
res = listener.accept() => {
let (tcp, client_addr) = match res {
Ok((tcp, client_addr)) => (tcp, client_addr),
Err(e) => {
error!("Failed to accept incoming connection: {}", e);
continue;
}
};
let client = client.clone();
let server = server.clone();
let ca = Arc::clone(&self.ca);
let http_handler = self.http_handler.clone();
let websocket_handler = self.websocket_handler.clone();
let websocket_connector = self.websocket_connector.clone();
shutdown.spawn_task_fn(move |guard| async move {
let conn = server.serve_connection_with_upgrades(
TokioIo::new(tcp),
service_fn(|req| {
InternalProxy {
ca: Arc::clone(&ca),
client: client.clone(),
server: server.clone(),
http_handler: http_handler.clone(),
websocket_handler: websocket_handler.clone(),
websocket_connector: websocket_connector.clone(),
client_addr,
}
.proxy(req)
}),
);
let mut conn = std::pin::pin!(conn);
if let Err(err) = tokio::select! {
conn = conn.as_mut() => conn,
_ = guard.cancelled() => {
conn.as_mut().graceful_shutdown();
conn.await
}
} {
error!("Error serving connection: {}", err);
}
});
}
_ = guard.cancelled() => {
break;
}
}
}
shutdown.shutdown().await;
Ok(())
}
}