mod internal;
pub mod builder;
pub use builder::ProxyBuilder;
use crate::mitm::{certificate_authority::CertificateAuthority, Error, HttpHandler, WebSocketHandler};
use builder::AddrOrListener;
use hyper_util::{
client::legacy::{
connect::{Connect, HttpConnector},
Client,
},
rt::{TokioExecutor, TokioIo},
server::conn::auto,
server::conn::auto::Builder,
};
use internal::InternalProxy;
use std::{future::Future, sync::Arc};
use std::future::Pending;
use tokio::net::TcpListener;
use tokio_graceful::Shutdown;
use tokio_tungstenite::Connector;
use hyper::service::service_fn;
use hyper_rustls::HttpsConnector;
use product_os_http_body::{BodyBytes as Body, BodyBytes};
#[cfg(feature = "tor")]
use tor_rtcompat::PreferredRuntime;
use crate::mitm::certificate_authority::RcgenAuthority;
use crate::ProxyMiddleware;
/// Proxy server from the `mitm` module, created via [`ProxyBuilder`] and run with
/// [`Proxy::start`].
///
/// Type parameters:
/// - `C`: connector used by the outbound [`Client`] (must implement `Connect` to start).
/// - `CA`: certificate authority (must implement `CertificateAuthority` to start).
/// - `H` / `W`: HTTP and WebSocket handlers, cloned into every connection.
/// - `F`: shutdown-signal future; once it completes, `start` stops accepting connections
///   and asks the open ones to shut down gracefully.
///
/// NOTE: fields are dropped in declaration order; keep that in mind before reordering.
pub struct Proxy<C, CA, H, W, F> {
/// Address to bind, or an already-bound listener; consumed by `start`.
address_or_listener: AddrOrListener,
/// Certificate authority, shared (via `Arc`) with every per-connection handler.
ca: Arc<CA>,
/// HTTP client cloned into every per-connection handler (presumably used to forward
/// requests upstream — TODO confirm in `internal`).
client: Client<C, BodyBytes>,
/// Optional `tokio_tungstenite` connector passed to every per-connection handler.
websocket_connector: Option<Connector>,
/// Handler for HTTP traffic, cloned per connection.
http_handler: H,
/// Handler for WebSocket traffic, cloned per connection.
websocket_handler: W,
/// Not read by `start`; its use is not visible in this file — TODO confirm.
certificates: product_os_security::certificates::Certificates,
/// Optional alternative request client passed to every per-connection handler; its
/// semantics live in `internal` (not visible here).
custom_requester: Option<product_os_request::ProductOSRequestClient>,
/// Compression configuration passed to every per-connection handler.
compression: product_os_configuration::NetworkProxyCompression,
/// Connection-serving builder. When `None`, `start` creates a default `auto::Builder`
/// with title-case and case-preserving HTTP/1 headers.
server: Option<Builder<TokioExecutor>>,
/// Future whose completion triggers graceful shutdown of the proxy.
graceful_shutdown: F,
/// Optional Tor client (only with the `tor` feature), cloned into every connection.
#[cfg(feature = "tor")]
tor_client: Option<arti_client::TorClient<tor_rtcompat::PreferredRuntime>>,
/// Optional VPN client (only with the `vpn` feature), cloned into every connection.
#[cfg(feature = "vpn")]
vpn_client: Option<product_os_vpn::ProductOSVPN>
}
impl Proxy<(), RcgenAuthority, ProxyMiddleware, ProxyMiddleware, Pending<()>> {
    /// Starts configuring a proxy whose outbound client uses an
    /// [`HttpsConnector`] wrapping an [`HttpConnector`].
    ///
    /// The builder begins with [`RcgenAuthority`] as the certificate authority,
    /// [`ProxyMiddleware`] for both the HTTP and WebSocket handlers, and
    /// [`Pending<()>`] (a future that never resolves) as the shutdown signal type.
    pub fn https_builder() -> ProxyBuilder<HttpsConnector<HttpConnector>, RcgenAuthority, ProxyMiddleware, ProxyMiddleware, Pending<()>> {
        ProxyBuilder::<HttpsConnector<HttpConnector>, _, _, _, _>::new()
    }

    /// Starts configuring a proxy whose outbound client uses a plain
    /// [`HttpConnector`].
    ///
    /// Apart from the connector, the starting types match
    /// [`Proxy::https_builder`].
    pub fn http_builder() -> ProxyBuilder<HttpConnector, RcgenAuthority, ProxyMiddleware, ProxyMiddleware, Pending<()>> {
        ProxyBuilder::<HttpConnector, _, _, _, _>::new()
    }
}
impl<C, CA, H, W, F> Proxy<C, CA, H, W, F>
where
    C: Connect + Clone + Send + Sync + 'static,
    CA: CertificateAuthority,
    H: HttpHandler + Clone,
    W: WebSocketHandler + Clone,
    F: Future<Output = ()> + Send + 'static,
{
    /// Runs the proxy until the configured graceful-shutdown future resolves.
    ///
    /// Binds the configured address (or adopts the supplied listener) and serves each
    /// accepted TCP connection on its own task, with HTTP upgrades enabled. Every
    /// request is handed to a fresh `InternalProxy` carrying clones of the
    /// proxy's shared state.
    ///
    /// When the shutdown future completes:
    /// - the accept loop stops;
    /// - the listener is closed;
    /// - every open connection is asked to shut down gracefully;
    /// - this method waits for all connection tasks to finish before returning.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the configured address fails. Failures accepting or
    /// serving individual connections are logged and do not stop the proxy.
    pub async fn start(self) -> Result<(), Error> {
        let listener = match self.address_or_listener {
            AddrOrListener::Addr(addr) => TcpListener::bind(addr).await?,
            AddrOrListener::Listener(listener) => listener,
        };
        let ca = Arc::clone(&self.ca);
        let client = self.client.clone();
        let websocket_connector = self.websocket_connector.clone();
        let http_handler = self.http_handler.clone();
        let websocket_handler = self.websocket_handler.clone();
        // Cloned once here (not on every loop iteration), then once per accepted connection.
        let custom_requester = self.custom_requester.clone();
        let compression = self.compression.clone();
        let shutdown = Shutdown::new(self.graceful_shutdown);
        let guard = shutdown.guard_weak();
        #[cfg(feature = "tor")]
        let tor_client = self.tor_client.clone();
        #[cfg(feature = "vpn")]
        let vpn_client = self.vpn_client.clone();
        let server = self.server.unwrap_or_else(|| {
            let mut builder = Builder::new(TokioExecutor::new());
            builder
                .http1()
                .title_case_headers(true)
                .preserve_header_case(true)
                .http2();
            builder
        });

        loop {
            tokio::select! {
                res = listener.accept() => {
                    let (tcp, client_addr) = match res {
                        Ok((tcp, client_addr)) => (tcp, client_addr),
                        Err(e) => {
                            // NOTE(review): persistent accept errors (e.g. fd exhaustion)
                            // are retried immediately; consider a short backoff.
                            tracing::error!("Failed to accept incoming connection: {}", e);
                            continue;
                        }
                    };

                    let server = server.clone();
                    let client = client.clone();
                    let ca = Arc::clone(&ca);
                    let http_handler = http_handler.clone();
                    let websocket_handler = websocket_handler.clone();
                    let websocket_connector = websocket_connector.clone();
                    let custom_requester = custom_requester.clone();
                    let compression = compression.clone();
                    #[cfg(feature = "tor")]
                    let tor_client = tor_client.clone();
                    #[cfg(feature = "vpn")]
                    let vpn_client = vpn_client.clone();

                    // The task is tracked by `shutdown`, so the final
                    // `shutdown.shutdown().await` waits for it.
                    shutdown.spawn_task_fn(move |guard| async move {
                        let conn = server.serve_connection_with_upgrades(
                            TokioIo::new(tcp),
                            service_fn(|req| {
                                InternalProxy {
                                    ca: Arc::clone(&ca),
                                    client: client.clone(),
                                    http_handler: http_handler.clone(),
                                    websocket_handler: websocket_handler.clone(),
                                    websocket_connector: websocket_connector.clone(),
                                    server: server.clone(),
                                    client_addr,
                                    #[cfg(feature = "tor")]
                                    tor_client: tor_client.clone(),
                                    #[cfg(feature = "vpn")]
                                    vpn_client: vpn_client.clone(),
                                    custom_requester: custom_requester.clone(),
                                    compression: compression.clone()
                                }
                                .proxy(req)
                            }),
                        );

                        let mut conn = std::pin::pin!(conn);

                        // Serve until the connection ends on its own, or until shutdown is
                        // signalled. In the shutdown case, ask hyper to finish gracefully
                        // and keep driving the connection to completion.
                        if let Err(err) = tokio::select! {
                            conn = conn.as_mut() => conn,
                            _ = guard.cancelled() => {
                                conn.as_mut().graceful_shutdown();
                                conn.await
                            }
                        } {
                            tracing::error!("Error serving connection: {}", err);
                        }
                    });
                }
                _ = guard.cancelled() => {
                    break;
                }
            }
        }

        // Close the listening socket so new clients are refused instead of waiting in the
        // backlog while existing connections drain.
        drop(listener);
        // Wait for every connection task spawned through `shutdown` to complete.
        // Without this, `start` would return while requests are still in flight, and a
        // caller that then exits or drops the runtime would abort them mid-response.
        shutdown.shutdown().await;

        Ok(())
    }
}