use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper_rustls::HttpsConnectorBuilder;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{mpsc::Sender, watch};
use tracing::{error, info};
use uuid::Uuid;
use crate::capture::CaptureSource;
use crate::capture::loop_detection::LoopDetector;
use crate::interceptor::{ConnectAction, ConnectionInfo, ENGINE_INDEX, Interceptor};
use crate::proxy::circuit_breaker::CircuitBreaker;
use crate::proxy::http::handle_request;
use crate::proxy::http_utils::HttpsClient;
use crate::proxy::outbound::{
DirectConnector, HttpUpstreamConnector, HttpsUpstreamConnector, OutboundConnector,
};
use crate::rule::engine::executor::{ConnectOverride, RuleEngine};
use crate::tls::CertificateAuthority;
use chrono::Utc;
use relay_core_api::flow::{Flow, FlowUpdate, Layer, NetworkInfo, TcpLayer, TransportProtocol};
use relay_core_api::policy::ProxyPolicy;
use relay_core_api::rule::RuleStage;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use url::Url;
/// Process-wide counter that hands each accepted connection the value it runs
/// under in the `ENGINE_INDEX` scope (see the `ENGINE_INDEX.scope(..)` call in
/// `start_proxy`). Each `fetch_add` returns a distinct value until the counter
/// wraps; `Ordering::Relaxed` is used there, presumably because only distinct
/// values (not cross-thread ordering) are needed.
static CONN_COUNTER: AtomicUsize = AtomicUsize::new(0);
#[allow(clippy::too_many_arguments)]
pub async fn start_proxy<S>(
mut source: S,
on_flow: Sender<FlowUpdate>,
interceptor: Arc<dyn Interceptor>,
ca: Arc<CertificateAuthority>,
policy: watch::Receiver<ProxyPolicy>,
client: Option<Arc<HttpsClient>>,
shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
rule_engine: Option<Arc<RuleEngine>>,
) -> crate::error::Result<()>
where
S: CaptureSource + Send + 'static,
S::IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
info!("RelayCore Proxy starting...");
info!("CA Loaded. Root cert:\n{}", ca.get_ca_cert_pem());
let startup_policy = policy.borrow().clone();
info!("Proxy Policy: {:?}", startup_policy);
let client = if let Some(c) = client {
c
} else {
let https = HttpsConnectorBuilder::new()
.with_native_roots()?
.https_or_http()
.enable_http1()
.enable_http2()
.build();
let client: HttpsClient = Client::builder(hyper_util::rt::TokioExecutor::new())
.timer(hyper_util::rt::TokioTimer::new())
.pool_idle_timeout(Duration::from_secs(60))
.pool_max_idle_per_host(32)
.http2_initial_stream_window_size(2 * 1024 * 1024)
.http2_initial_connection_window_size(4 * 1024 * 1024)
.http2_keep_alive_interval(Duration::from_secs(20))
.http2_keep_alive_timeout(Duration::from_secs(10))
.build(https);
Arc::new(client)
};
let connector: Arc<dyn OutboundConnector> = match &startup_policy.upstream {
Some(upstream) => {
let scheme = Url::parse(&upstream.proxy_url)
.map(|u| u.scheme().to_string())
.unwrap_or_else(|_| "http".to_string());
let result = if scheme == "https" {
HttpsUpstreamConnector::new(upstream)
.await
.map(|c| Arc::new(c) as Arc<dyn OutboundConnector>)
} else {
HttpUpstreamConnector::new(upstream)
.await
.map(|c| Arc::new(c) as Arc<dyn OutboundConnector>)
};
match result {
Ok(c) => c,
Err(e) => {
if upstream.fail_open {
tracing::warn!(
"Failed to create upstream connector: {:?}, falling back to direct (fail_open=true)",
e
);
Arc::new(DirectConnector::new(client.clone()))
} else {
tracing::error!(
"Failed to create upstream connector: {:?}, aborting startup (fail_open=false)",
e
);
return Err(crate::error::RelayError::Proxy(format!(
"upstream proxy configuration failed: {}",
e
)));
}
}
}
}
None => Arc::new(DirectConnector::new(client.clone())),
};
let listen_addrs = source.listen_addrs().into_iter().collect();
let loop_detector = Arc::new(LoopDetector::new(listen_addrs));
{
let loop_detector_bg = loop_detector.clone();
tokio::spawn(async move {
loop_detector_bg.refresh_local_addrs().await;
let mut ticker = tokio::time::interval(Duration::from_secs(60));
loop {
ticker.tick().await;
loop_detector_bg.refresh_local_addrs().await;
}
});
}
let circuit_breaker = Arc::new(CircuitBreaker::default());
let mut shutdown_rx = shutdown_rx;
loop {
let connection_result = tokio::select! {
res = source.accept() => res,
_ = async {
if let Some(rx) = shutdown_rx.as_mut() {
rx.await.ok();
} else {
std::future::pending::<()>().await;
}
} => {
info!("RelayCore Proxy received shutdown signal");
break;
}
};
let connection = match connection_result {
Ok(conn) => conn,
Err(e) => {
error!("Error accepting connection: {}", e);
continue;
}
};
let stream = connection.stream;
let client_addr = connection.client_addr;
let target_addr = connection.target_addr;
let conn_id = Uuid::new_v4();
let conn_info = ConnectionInfo {
id: conn_id,
client_addr,
server_addr: target_addr,
tls_sni: None,
};
match interceptor.on_connect(&conn_info).await {
ConnectAction::Drop { reason } => {
info!("Connection {} dropped by onConnect: {}", conn_id, reason);
interceptor
.on_disconnect(&conn_info, &Default::default())
.await;
continue;
}
ConnectAction::Allow => {}
}
let mut connect_target = target_addr;
if let Some(ref engine) = rule_engine
&& engine.has_rules_for_stage(RuleStage::Connect)
{
let mut flow = Flow {
id: conn_id,
start_time: Utc::now(),
end_time: None,
network: NetworkInfo {
client_ip: client_addr.ip().to_string(),
client_port: client_addr.port(),
server_ip: target_addr.map(|a| a.ip().to_string()).unwrap_or_default(),
server_port: target_addr.map(|a| a.port()).unwrap_or(0),
protocol: TransportProtocol::TCP,
tls: false,
tls_version: None,
sni: None,
},
layer: Layer::Tcp(TcpLayer {
bytes_up: 0,
bytes_down: 0,
}),
tags: vec![],
meta: HashMap::new(),
resilience_trace: None,
rule_variables: HashMap::new(),
matched_rules: vec![],
};
let ctx = engine.execute(RuleStage::Connect, &mut flow).await;
if ctx.is_terminated() {
info!("Connection {} terminated by Connect stage rule", conn_id);
interceptor
.on_disconnect(&conn_info, &Default::default())
.await;
continue;
}
if let Some(conn_override) = &ctx.connect_override {
match conn_override {
ConnectOverride::ForwardPort { host: _, port } => {
connect_target = Some(SocketAddr::new(
target_addr
.map(|a| a.ip())
.unwrap_or(IpAddr::from([0, 0, 0, 0])),
*port,
));
tracing::debug!("Connect stage ForwardPort -> port {}", port);
}
ConnectOverride::RedirectIp { ip } => {
let port = connect_target.map(|a| a.port()).unwrap_or(0);
connect_target = Some(SocketAddr::new(*ip, port));
tracing::debug!("Connect stage RedirectIp -> {}", ip);
}
ConnectOverride::SetTtl { ttl } => {
tracing::warn!(
"Connect stage SetTtl({}) is not yet implemented; TTL unchanged",
ttl
);
}
}
}
}
let target_addr = connect_target;
let io = TokioIo::new(stream);
let on_flow = on_flow.clone();
let ca = ca.clone();
let connector = connector.clone();
let interceptor = interceptor.clone();
let policy = policy.clone();
let loop_detector = loop_detector.clone();
let circuit_breaker = circuit_breaker.clone();
let engine_index = CONN_COUNTER.fetch_add(1, Ordering::Relaxed);
let conn_info_2 = conn_info.clone();
let interceptor_2 = interceptor.clone();
tokio::task::spawn(ENGINE_INDEX.scope(engine_index, async move {
let conn_start = Instant::now();
let result = http1::Builder::new()
.timer(hyper_util::rt::TokioTimer::new())
.header_read_timeout(Duration::from_secs(10))
.preserve_header_case(true)
.title_case_headers(true)
.serve_connection(
io,
service_fn(move |req| {
handle_request(
req,
client_addr,
on_flow.clone(),
ca.clone(),
connector.clone(),
interceptor.clone(),
target_addr,
policy.clone(),
loop_detector.clone(),
circuit_breaker.clone(),
)
}),
)
.with_upgrades()
.await;
let stats = crate::interceptor::ConnectionStats {
duration_ms: conn_start.elapsed().as_millis() as u64,
..Default::default()
};
interceptor_2.on_disconnect(&conn_info_2, &stats).await;
if let Err(err) = result {
error!("Error serving connection: {:?}", err);
}
}));
}
Ok(())
}