pub(crate) mod forward;
pub(crate) mod reverse;
use std::{future::Future, net::SocketAddr, path::PathBuf, sync::Arc};
use hyper::Uri;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::rt::TokioExecutor;
use proxyapi_models::ProxiedRequest;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use crate::body::ProxyBody;
use crate::ca::Ssl;
use crate::error::Error;
use crate::event::ProxyEvent;
use crate::handler::CapturingHandler;
use crate::intercept::InterceptConfig;
#[cfg(feature = "scripting")]
use crate::scripting::ScriptEngine;
/// Upstream HTTP client type used by the proxy's connection and replay handlers.
///
/// A `hyper_util` legacy client whose connector is a `hyper_rustls::HttpsConnector`
/// wrapping a plain [`HttpConnector`], sending [`ProxyBody`] request bodies.
pub(crate) type Client =
hyper_util::client::legacy::Client<hyper_rustls::HttpsConnector<HttpConnector>, ProxyBody>;
/// Returns `true` when `e`'s message matches one of the known benign shutdown patterns.
///
/// The check is purely textual. It renders `e` with `Display` and searches for any of a
/// fixed set of fragments, in order. Only `e`'s own message is inspected; its
/// `source()` chain is not walked.
pub(crate) fn is_benign_shutdown_error(e: &dyn std::error::Error) -> bool {
    const BENIGN_FRAGMENTS: [&str; 2] = ["shutting down", "connection was not closed cleanly"];

    let rendered = e.to_string();
    BENIGN_FRAGMENTS
        .iter()
        .any(|&fragment| rendered.contains(fragment))
}
/// Settings consumed by `Proxy::start`.
pub struct ProxyConfig {
/// Socket address the proxy's TCP listener is bound to.
pub addr: SocketAddr,
/// Selects whether accepted connections go to the forward or the reverse handler.
pub mode: ProxyMode,
/// Sender for [`ProxyEvent`]s; every `CapturingHandler` the proxy creates gets a clone.
pub event_tx: mpsc::Sender<ProxyEvent>,
/// Directory passed to `Ssl::load_or_generate` on a blocking thread at startup.
/// NOTE(review): presumably a CA is created here when none exists — confirm in `ca`.
pub ca_dir: PathBuf,
/// Optional intercept configuration, shared via `Arc` with every handler.
pub intercept: Option<Arc<InterceptConfig>>,
/// Optional Lua script loaded once at startup; a load error makes `Proxy::start` fail.
#[cfg(feature = "scripting")]
pub script_path: Option<PathBuf>,
/// Optional channel of requests handed to `forward::handle_replay`, one spawned task per
/// request. `None` disables replay.
pub replay_rx: Option<mpsc::Receiver<ProxiedRequest>>,
}
/// How the proxy dispatches each accepted connection.
#[derive(Debug, Clone)]
pub enum ProxyMode {
/// Each connection is served by `forward::handle_connection`, which also receives the
/// CA and the proxy's listen address.
Forward,
/// Each connection is served by `reverse::handle_connection` against a fixed upstream.
Reverse {
/// Upstream URI handed (cloned) to the reverse handler for every connection.
target: Uri,
},
}
/// A proxy server built from a [`ProxyConfig`]; call `Proxy::start` to run it.
pub struct Proxy {
/// Configuration consumed when the proxy is started.
config: ProxyConfig,
}
impl Proxy {
pub const fn new(config: ProxyConfig) -> Self {
Self { config }
}
pub async fn start(self, shutdown: impl Future<Output = ()>) -> Result<(), Error> {
let listener = TcpListener::bind(self.config.addr).await?;
tracing::info!("Proxy listening on {}", self.config.addr);
#[cfg(feature = "scripting")]
let script_engine: Option<Arc<ScriptEngine>> = self
.config
.script_path
.as_ref()
.map(|p| {
tracing::info!("Loading Lua script: {}", p.display());
ScriptEngine::new(p).map(Arc::new)
})
.transpose()?;
let ca_dir = self.config.ca_dir.clone();
let ca =
Arc::new(tokio::task::spawn_blocking(move || Ssl::load_or_generate(&ca_dir)).await??);
let https = hyper_rustls::HttpsConnectorBuilder::new()
.with_webpki_roots()
.https_or_http()
.enable_http1()
.build();
let client = Arc::new(
hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(https),
);
let mut replay_rx = self.config.replay_rx;
tokio::pin!(shutdown);
loop {
tokio::select! {
result = listener.accept() => {
let (stream, remote_addr) = match result {
Ok(conn) => conn,
Err(e) => {
tracing::warn!("Failed to accept connection: {e}");
continue;
}
};
let mut handler = CapturingHandler::new(self.config.event_tx.clone());
if let Some(ref ic) = self.config.intercept {
handler = handler.with_intercept(Arc::clone(ic));
}
#[cfg(feature = "scripting")]
if let Some(ref engine) = script_engine {
handler = handler.with_script_engine(Arc::clone(engine));
}
let ca = Arc::clone(&ca);
let client = Arc::clone(&client);
match &self.config.mode {
ProxyMode::Forward => {
let listen_addr = self.config.addr;
tokio::spawn(forward::handle_connection(
stream, remote_addr, handler, ca, client, listen_addr,
));
}
ProxyMode::Reverse { target } => {
let target = target.clone();
tokio::spawn(reverse::handle_connection(
stream, remote_addr, handler, target, client,
));
}
}
}
Some(req) = recv_replay(&mut replay_rx) => {
let mut handler = CapturingHandler::new(self.config.event_tx.clone());
if let Some(ref ic) = self.config.intercept {
handler = handler.with_intercept(Arc::clone(ic));
}
#[cfg(feature = "scripting")]
if let Some(ref engine) = script_engine {
handler = handler.with_script_engine(Arc::clone(engine));
}
tokio::spawn(forward::handle_replay(req, handler, Arc::clone(&client)));
}
() = &mut shutdown => {
tracing::info!("Proxy shutting down");
break;
}
}
}
Ok(())
}
}
/// Awaits the next request from the optional replay channel.
///
/// With no receiver configured, the returned future never resolves, so a `select!`
/// branch built on it simply stays idle. With a receiver, it defers to
/// `mpsc::Receiver::recv`, which yields `None` once the channel is closed and drained.
async fn recv_replay(rx: &mut Option<mpsc::Receiver<ProxiedRequest>>) -> Option<ProxiedRequest> {
    if let Some(receiver) = rx.as_mut() {
        receiver.recv().await
    } else {
        std::future::pending().await
    }
}