1use hyper::server::conn::http1;
2use hyper::service::service_fn;
3use hyper_rustls::HttpsConnectorBuilder;
4use hyper_util::client::legacy::Client;
5use hyper_util::rt::TokioIo;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::time::{Duration, Instant};
9use tokio::sync::{mpsc::Sender, watch};
10use tracing::{error, info};
11use uuid::Uuid;
12
13use crate::capture::CaptureSource;
14use crate::capture::loop_detection::LoopDetector;
15use crate::interceptor::{ConnectAction, ConnectionInfo, ENGINE_INDEX, Interceptor};
16use crate::proxy::circuit_breaker::CircuitBreaker;
17use crate::proxy::http::handle_request;
18use crate::proxy::http_utils::HttpsClient;
19use crate::proxy::outbound::{
20 DirectConnector, HttpUpstreamConnector, HttpsUpstreamConnector, OutboundConnector,
21};
22use crate::rule::engine::executor::{ConnectOverride, RuleEngine};
23use crate::tls::CertificateAuthority;
24use chrono::Utc;
25use relay_core_api::flow::{Flow, FlowUpdate, Layer, NetworkInfo, TcpLayer, TransportProtocol};
26use relay_core_api::policy::ProxyPolicy;
27use relay_core_api::rule::RuleStage;
28use std::collections::HashMap;
29use std::net::{IpAddr, SocketAddr};
30use url::Url;
31
32static CONN_COUNTER: AtomicUsize = AtomicUsize::new(0);
33
34#[allow(clippy::too_many_arguments)]
36pub async fn start_proxy<S>(
37 mut source: S,
38 on_flow: Sender<FlowUpdate>,
39 interceptor: Arc<dyn Interceptor>,
40 ca: Arc<CertificateAuthority>,
41 policy: watch::Receiver<ProxyPolicy>,
42 client: Option<Arc<HttpsClient>>,
43 shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
44 rule_engine: Option<Arc<RuleEngine>>,
45) -> crate::error::Result<()>
46where
47 S: CaptureSource + Send + 'static,
48 S::IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
49{
50 tracing::debug!("RelayCore proxy engine starting");
51 info!("CA Loaded. Root cert:\n{}", ca.get_ca_cert_pem());
52 let startup_policy = policy.borrow().clone();
53 info!("Proxy Policy: {:?}", startup_policy);
54
55 let client = if let Some(c) = client {
57 c
58 } else {
59 let https = HttpsConnectorBuilder::new()
60 .with_native_roots()?
61 .https_or_http()
62 .enable_http1()
63 .enable_http2()
64 .build();
65 let client: HttpsClient = Client::builder(hyper_util::rt::TokioExecutor::new())
66 .timer(hyper_util::rt::TokioTimer::new())
67 .pool_idle_timeout(Duration::from_secs(60))
68 .pool_max_idle_per_host(32)
69 .http2_initial_stream_window_size(2 * 1024 * 1024)
70 .http2_initial_connection_window_size(4 * 1024 * 1024)
71 .http2_keep_alive_interval(Duration::from_secs(20))
72 .http2_keep_alive_timeout(Duration::from_secs(10))
73 .build(https);
74 Arc::new(client)
75 };
76
77 let connector: Arc<dyn OutboundConnector> = match &startup_policy.upstream {
79 Some(upstream) => {
80 let scheme = Url::parse(&upstream.proxy_url)
81 .map(|u| u.scheme().to_string())
82 .unwrap_or_else(|_| "http".to_string());
83 let result = if scheme == "https" {
84 HttpsUpstreamConnector::new(upstream)
85 .await
86 .map(|c| Arc::new(c) as Arc<dyn OutboundConnector>)
87 } else {
88 HttpUpstreamConnector::new(upstream)
89 .await
90 .map(|c| Arc::new(c) as Arc<dyn OutboundConnector>)
91 };
92 match result {
93 Ok(c) => c,
94 Err(e) => {
95 if upstream.fail_open {
96 tracing::warn!(
97 "Failed to create upstream connector: {:?}, falling back to direct (fail_open=true)",
98 e
99 );
100 Arc::new(DirectConnector::new(client.clone()))
101 } else {
102 tracing::error!(
103 "Failed to create upstream connector: {:?}, aborting startup (fail_open=false)",
104 e
105 );
106 return Err(crate::error::RelayError::Proxy(format!(
107 "upstream proxy configuration failed: {}",
108 e
109 )));
110 }
111 }
112 }
113 }
114 None => Arc::new(DirectConnector::new(client.clone())),
115 };
116
117 let listen_addrs = source.listen_addrs().into_iter().collect();
119 let loop_detector = Arc::new(LoopDetector::new(listen_addrs));
120 {
121 let loop_detector_bg = loop_detector.clone();
122 tokio::spawn(async move {
123 loop_detector_bg.refresh_local_addrs().await;
125 let mut ticker = tokio::time::interval(Duration::from_secs(60));
126 loop {
127 ticker.tick().await;
128 loop_detector_bg.refresh_local_addrs().await;
129 }
130 });
131 }
132
133 let circuit_breaker = Arc::new(CircuitBreaker::default());
135
136 let mut shutdown_rx = shutdown_rx;
137
138 loop {
139 let connection_result = tokio::select! {
141 res = source.accept() => res,
142 _ = async {
143 if let Some(rx) = shutdown_rx.as_mut() {
144 rx.await.ok();
145 } else {
146 std::future::pending::<()>().await;
147 }
148 } => {
149 info!("RelayCore Proxy received shutdown signal");
150 break;
151 }
152 };
153
154 let connection = match connection_result {
155 Ok(conn) => conn,
156 Err(e) => {
157 error!("Error accepting connection: {}", e);
158 continue;
159 }
160 };
161
162 let stream = connection.stream;
163 let client_addr = connection.client_addr;
164 let target_addr = connection.target_addr;
165
166 let conn_id = Uuid::new_v4();
167 let conn_info = ConnectionInfo {
168 id: conn_id,
169 client_addr,
170 server_addr: target_addr,
171 tls_sni: None,
173 };
174
175 match interceptor.on_connect(&conn_info).await {
176 ConnectAction::Drop { reason } => {
177 info!("Connection {} dropped by onConnect: {}", conn_id, reason);
178 interceptor
179 .on_disconnect(&conn_info, &Default::default())
180 .await;
181 continue;
182 }
183 ConnectAction::Allow => {}
184 }
185
186 let mut connect_target = target_addr;
187 if let Some(ref engine) = rule_engine
188 && engine.has_rules_for_stage(RuleStage::Connect)
189 {
190 let mut flow = Flow {
191 id: conn_id,
192 start_time: Utc::now(),
193 end_time: None,
194 network: NetworkInfo {
195 client_ip: client_addr.ip().to_string(),
196 client_port: client_addr.port(),
197 server_ip: target_addr.map(|a| a.ip().to_string()).unwrap_or_default(),
198 server_port: target_addr.map(|a| a.port()).unwrap_or(0),
199 protocol: TransportProtocol::TCP,
200 tls: false,
201 tls_version: None,
202 sni: None,
203 },
204 layer: Layer::Tcp(TcpLayer {
205 bytes_up: 0,
206 bytes_down: 0,
207 }),
208 tags: vec![],
209 meta: HashMap::new(),
210 resilience_trace: None,
211 rule_variables: HashMap::new(),
212 matched_rules: vec![],
213 };
214 let ctx = engine.execute(RuleStage::Connect, &mut flow).await;
215 if ctx.is_terminated() {
216 info!("Connection {} terminated by Connect stage rule", conn_id);
217 interceptor
218 .on_disconnect(&conn_info, &Default::default())
219 .await;
220 continue;
221 }
222 if let Some(conn_override) = &ctx.connect_override {
223 match conn_override {
224 ConnectOverride::ForwardPort { host: _, port } => {
225 connect_target = Some(SocketAddr::new(
226 target_addr
227 .map(|a| a.ip())
228 .unwrap_or(IpAddr::from([0, 0, 0, 0])),
229 *port,
230 ));
231 tracing::debug!("Connect stage ForwardPort -> port {}", port);
232 }
233 ConnectOverride::RedirectIp { ip } => {
234 let port = connect_target.map(|a| a.port()).unwrap_or(0);
235 connect_target = Some(SocketAddr::new(*ip, port));
236 tracing::debug!("Connect stage RedirectIp -> {}", ip);
237 }
238 ConnectOverride::SetTtl { ttl } => {
239 tracing::warn!(
244 "Connect stage SetTtl({}) is not yet implemented; TTL unchanged",
245 ttl
246 );
247 }
248 }
249 }
250 }
251
252 let target_addr = connect_target;
253
254 let io = TokioIo::new(stream);
255 let on_flow = on_flow.clone();
256 let ca = ca.clone();
257 let connector = connector.clone();
258 let interceptor = interceptor.clone();
259 let policy = policy.clone();
260 let loop_detector = loop_detector.clone();
261
262 let circuit_breaker = circuit_breaker.clone();
263
264 let engine_index = CONN_COUNTER.fetch_add(1, Ordering::Relaxed);
265
266 let conn_info_2 = conn_info.clone();
267 let interceptor_2 = interceptor.clone();
268 tokio::task::spawn(ENGINE_INDEX.scope(engine_index, async move {
269 let conn_start = Instant::now();
270 let result = http1::Builder::new()
271 .timer(hyper_util::rt::TokioTimer::new())
272 .header_read_timeout(Duration::from_secs(10))
273 .preserve_header_case(true)
274 .title_case_headers(true)
275 .serve_connection(
276 io,
277 service_fn(move |req| {
278 handle_request(
279 req,
280 client_addr,
281 on_flow.clone(),
282 ca.clone(),
283 connector.clone(),
284 interceptor.clone(),
285 target_addr,
286 policy.clone(),
287 loop_detector.clone(),
288 circuit_breaker.clone(),
289 )
290 }),
291 )
292 .with_upgrades()
293 .await;
294
295 let stats = crate::interceptor::ConnectionStats {
296 duration_ms: conn_start.elapsed().as_millis() as u64,
297 ..Default::default()
299 };
300 interceptor_2.on_disconnect(&conn_info_2, &stats).await;
301
302 if let Err(err) = result {
303 error!("Error serving connection: {:?}", err);
304 }
305 }));
306 }
307
308 Ok(())
309}