1use hyper::server::conn::http1;
2use hyper::service::service_fn;
3use hyper_rustls::HttpsConnectorBuilder;
4use hyper_util::client::legacy::Client;
5use hyper_util::rt::TokioIo;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::time::{Duration, Instant};
9use tokio::sync::{mpsc::Sender, watch};
10use tracing::{error, info};
11use uuid::Uuid;
12
13use crate::capture::CaptureSource;
14use crate::capture::loop_detection::LoopDetector;
15use crate::interceptor::{ConnectAction, ConnectionInfo, ENGINE_INDEX, Interceptor};
16use crate::proxy::circuit_breaker::CircuitBreaker;
17use crate::proxy::http::handle_request;
18use crate::proxy::http_utils::HttpsClient;
19use crate::rule::engine::executor::{ConnectOverride, RuleEngine};
20use crate::tls::CertificateAuthority;
21use chrono::Utc;
22use relay_core_api::flow::{Flow, FlowUpdate, Layer, NetworkInfo, TcpLayer, TransportProtocol};
23use relay_core_api::policy::ProxyPolicy;
24use relay_core_api::rule::RuleStage;
25use std::collections::HashMap;
26use std::net::{IpAddr, SocketAddr};
27
28static CONN_COUNTER: AtomicUsize = AtomicUsize::new(0);
29
30#[allow(clippy::too_many_arguments)]
32pub async fn start_proxy<S>(
33 mut source: S,
34 on_flow: Sender<FlowUpdate>,
35 interceptor: Arc<dyn Interceptor>,
36 ca: Arc<CertificateAuthority>,
37 policy: watch::Receiver<ProxyPolicy>,
38 client: Option<Arc<HttpsClient>>,
39 shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
40 rule_engine: Option<Arc<RuleEngine>>,
41) -> crate::error::Result<()>
42where
43 S: CaptureSource + Send + 'static,
44 S::IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
45{
46 info!("RelayCore Proxy starting...");
47 info!("CA Loaded. Root cert:\n{}", ca.get_ca_cert_pem());
48 info!("Proxy Policy: {:?}", policy.borrow());
49
50 let client = if let Some(c) = client {
52 c
53 } else {
54 let https = HttpsConnectorBuilder::new()
55 .with_native_roots()?
56 .https_or_http()
57 .enable_http1()
58 .enable_http2()
59 .build();
60 let client: HttpsClient = Client::builder(hyper_util::rt::TokioExecutor::new())
61 .timer(hyper_util::rt::TokioTimer::new())
62 .pool_idle_timeout(Duration::from_secs(60))
63 .pool_max_idle_per_host(32)
64 .http2_initial_stream_window_size(2 * 1024 * 1024) .http2_initial_connection_window_size(4 * 1024 * 1024) .http2_keep_alive_interval(Duration::from_secs(20))
67 .http2_keep_alive_timeout(Duration::from_secs(10))
68 .build(https);
69 Arc::new(client)
70 };
71
72 let listen_addrs = source.listen_addrs().into_iter().collect();
74 let loop_detector = Arc::new(LoopDetector::new(listen_addrs));
75 {
76 let loop_detector_bg = loop_detector.clone();
77 tokio::spawn(async move {
78 loop_detector_bg.refresh_local_addrs().await;
80 let mut ticker = tokio::time::interval(Duration::from_secs(60));
81 loop {
82 ticker.tick().await;
83 loop_detector_bg.refresh_local_addrs().await;
84 }
85 });
86 }
87
88 let circuit_breaker = Arc::new(CircuitBreaker::default());
90
91 let mut shutdown_rx = shutdown_rx;
92
93 loop {
94 let connection_result = tokio::select! {
96 res = source.accept() => res,
97 _ = async {
98 if let Some(rx) = shutdown_rx.as_mut() {
99 rx.await.ok();
100 } else {
101 std::future::pending::<()>().await;
102 }
103 } => {
104 info!("RelayCore Proxy received shutdown signal");
105 break;
106 }
107 };
108
109 let connection = match connection_result {
110 Ok(conn) => conn,
111 Err(e) => {
112 error!("Error accepting connection: {}", e);
113 continue;
114 }
115 };
116
117 let stream = connection.stream;
118 let client_addr = connection.client_addr;
119 let target_addr = connection.target_addr;
120
121 let conn_id = Uuid::new_v4();
122 let conn_info = ConnectionInfo {
123 id: conn_id,
124 client_addr,
125 server_addr: target_addr,
126 tls_sni: None,
128 };
129
130 match interceptor.on_connect(&conn_info).await {
131 ConnectAction::Drop { reason } => {
132 info!("Connection {} dropped by onConnect: {}", conn_id, reason);
133 interceptor
134 .on_disconnect(&conn_info, &Default::default())
135 .await;
136 continue;
137 }
138 ConnectAction::Allow => {}
139 }
140
141 let mut connect_target = target_addr;
142 if let Some(ref engine) = rule_engine
143 && engine.has_rules_for_stage(RuleStage::Connect)
144 {
145 let mut flow = Flow {
146 id: conn_id,
147 start_time: Utc::now(),
148 end_time: None,
149 network: NetworkInfo {
150 client_ip: client_addr.ip().to_string(),
151 client_port: client_addr.port(),
152 server_ip: target_addr.map(|a| a.ip().to_string()).unwrap_or_default(),
153 server_port: target_addr.map(|a| a.port()).unwrap_or(0),
154 protocol: TransportProtocol::TCP,
155 tls: false,
156 tls_version: None,
157 sni: None,
158 },
159 layer: Layer::Tcp(TcpLayer {
160 bytes_up: 0,
161 bytes_down: 0,
162 }),
163 tags: vec![],
164 meta: HashMap::new(),
165 resilience_trace: None,
166 rule_variables: HashMap::new(),
167 matched_rules: vec![],
168 };
169 let ctx = engine.execute(RuleStage::Connect, &mut flow).await;
170 if ctx.is_terminated() {
171 info!("Connection {} terminated by Connect stage rule", conn_id);
172 interceptor
173 .on_disconnect(&conn_info, &Default::default())
174 .await;
175 continue;
176 }
177 if let Some(conn_override) = &ctx.connect_override {
178 match conn_override {
179 ConnectOverride::ForwardPort { host: _, port } => {
180 connect_target = Some(SocketAddr::new(
181 target_addr
182 .map(|a| a.ip())
183 .unwrap_or(IpAddr::from([0, 0, 0, 0])),
184 *port,
185 ));
186 tracing::debug!("Connect stage ForwardPort -> port {}", port);
187 }
188 ConnectOverride::RedirectIp { ip } => {
189 let port = connect_target.map(|a| a.port()).unwrap_or(0);
190 connect_target = Some(SocketAddr::new(*ip, port));
191 tracing::debug!("Connect stage RedirectIp -> {}", ip);
192 }
193 ConnectOverride::SetTtl { ttl } => {
194 tracing::warn!(
199 "Connect stage SetTtl({}) is not yet implemented; TTL unchanged",
200 ttl
201 );
202 }
203 }
204 }
205 }
206
207 let target_addr = connect_target;
208
209 let io = TokioIo::new(stream);
210 let on_flow = on_flow.clone();
211 let ca = ca.clone();
212 let client = client.clone();
213 let interceptor = interceptor.clone();
214 let policy = policy.clone();
215 let loop_detector = loop_detector.clone();
216
217 let circuit_breaker = circuit_breaker.clone();
218
219 let engine_index = CONN_COUNTER.fetch_add(1, Ordering::Relaxed);
220
221 let conn_info_2 = conn_info.clone();
222 let interceptor_2 = interceptor.clone();
223 tokio::task::spawn(ENGINE_INDEX.scope(engine_index, async move {
224 let conn_start = Instant::now();
225 let result = http1::Builder::new()
226 .timer(hyper_util::rt::TokioTimer::new())
227 .header_read_timeout(Duration::from_secs(10))
228 .preserve_header_case(true)
229 .title_case_headers(true)
230 .serve_connection(
231 io,
232 service_fn(move |req| {
233 handle_request(
234 req,
235 client_addr,
236 on_flow.clone(),
237 ca.clone(),
238 client.clone(),
239 interceptor.clone(),
240 target_addr,
241 policy.clone(),
242 loop_detector.clone(),
243 circuit_breaker.clone(),
244 )
245 }),
246 )
247 .with_upgrades()
248 .await;
249
250 let stats = crate::interceptor::ConnectionStats {
251 duration_ms: conn_start.elapsed().as_millis() as u64,
252 ..Default::default()
254 };
255 interceptor_2.on_disconnect(&conn_info_2, &stats).await;
256
257 if let Err(err) = result {
258 error!("Error serving connection: {:?}", err);
259 }
260 }));
261 }
262
263 Ok(())
264}