Skip to main content

relay_core_lib/proxy/
server.rs

1use hyper::server::conn::http1;
2use hyper::service::service_fn;
3use hyper_rustls::HttpsConnectorBuilder;
4use hyper_util::client::legacy::Client;
5use hyper_util::rt::TokioIo;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::time::{Duration, Instant};
9use tokio::sync::{mpsc::Sender, watch};
10use tracing::{error, info};
11use uuid::Uuid;
12
13use crate::capture::CaptureSource;
14use crate::capture::loop_detection::LoopDetector;
15use crate::interceptor::{ConnectAction, ConnectionInfo, ENGINE_INDEX, Interceptor};
16use crate::proxy::circuit_breaker::CircuitBreaker;
17use crate::proxy::http::handle_request;
18use crate::proxy::http_utils::HttpsClient;
19use crate::proxy::outbound::{
20    DirectConnector, HttpUpstreamConnector, HttpsUpstreamConnector, OutboundConnector,
21};
22use crate::rule::engine::executor::{ConnectOverride, RuleEngine};
23use crate::tls::CertificateAuthority;
24use chrono::Utc;
25use relay_core_api::flow::{Flow, FlowUpdate, Layer, NetworkInfo, TcpLayer, TransportProtocol};
26use relay_core_api::policy::ProxyPolicy;
27use relay_core_api::rule::RuleStage;
28use std::collections::HashMap;
29use std::net::{IpAddr, SocketAddr};
30use url::Url;
31
/// Process-wide counter of connections that `start_proxy` hands to a serving task.
///
/// `start_proxy` calls `fetch_add(1, Ordering::Relaxed)` right before spawning the
/// per-connection task and passes the previous value to `ENGINE_INDEX.scope(..)`
/// around that task. Connections rejected earlier (by `on_connect` or by a
/// Connect-stage rule) do not advance the counter.
static CONN_COUNTER: AtomicUsize = AtomicUsize::new(0);
33
34/// Start the HTTP Proxy Server
35#[allow(clippy::too_many_arguments)]
36pub async fn start_proxy<S>(
37    mut source: S,
38    on_flow: Sender<FlowUpdate>,
39    interceptor: Arc<dyn Interceptor>,
40    ca: Arc<CertificateAuthority>,
41    policy: watch::Receiver<ProxyPolicy>,
42    client: Option<Arc<HttpsClient>>,
43    shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
44    rule_engine: Option<Arc<RuleEngine>>,
45) -> crate::error::Result<()>
46where
47    S: CaptureSource + Send + 'static,
48    S::IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
49{
50    info!("RelayCore Proxy starting...");
51    info!("CA Loaded. Root cert:\n{}", ca.get_ca_cert_pem());
52    let startup_policy = policy.borrow().clone();
53    info!("Proxy Policy: {:?}", startup_policy);
54
55    // Initialize HTTP Client (shared) — used by DirectConnector
56    let client = if let Some(c) = client {
57        c
58    } else {
59        let https = HttpsConnectorBuilder::new()
60            .with_native_roots()?
61            .https_or_http()
62            .enable_http1()
63            .enable_http2()
64            .build();
65        let client: HttpsClient = Client::builder(hyper_util::rt::TokioExecutor::new())
66            .timer(hyper_util::rt::TokioTimer::new())
67            .pool_idle_timeout(Duration::from_secs(60))
68            .pool_max_idle_per_host(32)
69            .http2_initial_stream_window_size(2 * 1024 * 1024)
70            .http2_initial_connection_window_size(4 * 1024 * 1024)
71            .http2_keep_alive_interval(Duration::from_secs(20))
72            .http2_keep_alive_timeout(Duration::from_secs(10))
73            .build(https);
74        Arc::new(client)
75    };
76
77    // Initialize OutboundConnector based on ProxyPolicy.upstream
78    let connector: Arc<dyn OutboundConnector> = match &startup_policy.upstream {
79        Some(upstream) => {
80            let scheme = Url::parse(&upstream.proxy_url)
81                .map(|u| u.scheme().to_string())
82                .unwrap_or_else(|_| "http".to_string());
83            let result = if scheme == "https" {
84                HttpsUpstreamConnector::new(upstream)
85                    .await
86                    .map(|c| Arc::new(c) as Arc<dyn OutboundConnector>)
87            } else {
88                HttpUpstreamConnector::new(upstream)
89                    .await
90                    .map(|c| Arc::new(c) as Arc<dyn OutboundConnector>)
91            };
92            match result {
93                Ok(c) => c,
94                Err(e) => {
95                    if upstream.fail_open {
96                        tracing::warn!(
97                            "Failed to create upstream connector: {:?}, falling back to direct (fail_open=true)",
98                            e
99                        );
100                        Arc::new(DirectConnector::new(client.clone()))
101                    } else {
102                        tracing::error!(
103                            "Failed to create upstream connector: {:?}, aborting startup (fail_open=false)",
104                            e
105                        );
106                        return Err(crate::error::RelayError::Proxy(format!(
107                            "upstream proxy configuration failed: {}",
108                            e
109                        )));
110                    }
111                }
112            }
113        }
114        None => Arc::new(DirectConnector::new(client.clone())),
115    };
116
117    // Initialize Loop Detector
118    let listen_addrs = source.listen_addrs().into_iter().collect();
119    let loop_detector = Arc::new(LoopDetector::new(listen_addrs));
120    {
121        let loop_detector_bg = loop_detector.clone();
122        tokio::spawn(async move {
123            // Prime local interface cache at startup.
124            loop_detector_bg.refresh_local_addrs().await;
125            let mut ticker = tokio::time::interval(Duration::from_secs(60));
126            loop {
127                ticker.tick().await;
128                loop_detector_bg.refresh_local_addrs().await;
129            }
130        });
131    }
132
133    // Initialize Circuit Breaker (P3)
134    let circuit_breaker = Arc::new(CircuitBreaker::default());
135
136    let mut shutdown_rx = shutdown_rx;
137
138    loop {
139        // Accept connection from abstract source or handle shutdown
140        let connection_result = tokio::select! {
141            res = source.accept() => res,
142            _ = async {
143                if let Some(rx) = shutdown_rx.as_mut() {
144                    rx.await.ok();
145                } else {
146                    std::future::pending::<()>().await;
147                }
148            } => {
149                info!("RelayCore Proxy received shutdown signal");
150                break;
151            }
152        };
153
154        let connection = match connection_result {
155            Ok(conn) => conn,
156            Err(e) => {
157                error!("Error accepting connection: {}", e);
158                continue;
159            }
160        };
161
162        let stream = connection.stream;
163        let client_addr = connection.client_addr;
164        let target_addr = connection.target_addr;
165
166        let conn_id = Uuid::new_v4();
167        let conn_info = ConnectionInfo {
168            id: conn_id,
169            client_addr,
170            server_addr: target_addr,
171            // TODO: extract SNI from TLS ClientHello (requires pre-rustls intercept)
172            tls_sni: None,
173        };
174
175        match interceptor.on_connect(&conn_info).await {
176            ConnectAction::Drop { reason } => {
177                info!("Connection {} dropped by onConnect: {}", conn_id, reason);
178                interceptor
179                    .on_disconnect(&conn_info, &Default::default())
180                    .await;
181                continue;
182            }
183            ConnectAction::Allow => {}
184        }
185
186        let mut connect_target = target_addr;
187        if let Some(ref engine) = rule_engine
188            && engine.has_rules_for_stage(RuleStage::Connect)
189        {
190            let mut flow = Flow {
191                id: conn_id,
192                start_time: Utc::now(),
193                end_time: None,
194                network: NetworkInfo {
195                    client_ip: client_addr.ip().to_string(),
196                    client_port: client_addr.port(),
197                    server_ip: target_addr.map(|a| a.ip().to_string()).unwrap_or_default(),
198                    server_port: target_addr.map(|a| a.port()).unwrap_or(0),
199                    protocol: TransportProtocol::TCP,
200                    tls: false,
201                    tls_version: None,
202                    sni: None,
203                },
204                layer: Layer::Tcp(TcpLayer {
205                    bytes_up: 0,
206                    bytes_down: 0,
207                }),
208                tags: vec![],
209                meta: HashMap::new(),
210                resilience_trace: None,
211                rule_variables: HashMap::new(),
212                matched_rules: vec![],
213            };
214            let ctx = engine.execute(RuleStage::Connect, &mut flow).await;
215            if ctx.is_terminated() {
216                info!("Connection {} terminated by Connect stage rule", conn_id);
217                interceptor
218                    .on_disconnect(&conn_info, &Default::default())
219                    .await;
220                continue;
221            }
222            if let Some(conn_override) = &ctx.connect_override {
223                match conn_override {
224                    ConnectOverride::ForwardPort { host: _, port } => {
225                        connect_target = Some(SocketAddr::new(
226                            target_addr
227                                .map(|a| a.ip())
228                                .unwrap_or(IpAddr::from([0, 0, 0, 0])),
229                            *port,
230                        ));
231                        tracing::debug!("Connect stage ForwardPort -> port {}", port);
232                    }
233                    ConnectOverride::RedirectIp { ip } => {
234                        let port = connect_target.map(|a| a.port()).unwrap_or(0);
235                        connect_target = Some(SocketAddr::new(*ip, port));
236                        tracing::debug!("Connect stage RedirectIp -> {}", ip);
237                    }
238                    ConnectOverride::SetTtl { ttl } => {
239                        // Socket-level TTL application requires access to the raw
240                        // TcpStream before it enters the hyper HTTP stack, which is
241                        // not feasible with the current connection model.
242                        // Deferred to 1.x.
243                        tracing::warn!(
244                            "Connect stage SetTtl({}) is not yet implemented; TTL unchanged",
245                            ttl
246                        );
247                    }
248                }
249            }
250        }
251
252        let target_addr = connect_target;
253
254        let io = TokioIo::new(stream);
255        let on_flow = on_flow.clone();
256        let ca = ca.clone();
257        let connector = connector.clone();
258        let interceptor = interceptor.clone();
259        let policy = policy.clone();
260        let loop_detector = loop_detector.clone();
261
262        let circuit_breaker = circuit_breaker.clone();
263
264        let engine_index = CONN_COUNTER.fetch_add(1, Ordering::Relaxed);
265
266        let conn_info_2 = conn_info.clone();
267        let interceptor_2 = interceptor.clone();
268        tokio::task::spawn(ENGINE_INDEX.scope(engine_index, async move {
269            let conn_start = Instant::now();
270            let result = http1::Builder::new()
271                .timer(hyper_util::rt::TokioTimer::new())
272                .header_read_timeout(Duration::from_secs(10))
273                .preserve_header_case(true)
274                .title_case_headers(true)
275                .serve_connection(
276                    io,
277                    service_fn(move |req| {
278                        handle_request(
279                            req,
280                            client_addr,
281                            on_flow.clone(),
282                            ca.clone(),
283                            connector.clone(),
284                            interceptor.clone(),
285                            target_addr,
286                            policy.clone(),
287                            loop_detector.clone(),
288                            circuit_breaker.clone(),
289                        )
290                    }),
291                )
292                .with_upgrades()
293                .await;
294
295            let stats = crate::interceptor::ConnectionStats {
296                duration_ms: conn_start.elapsed().as_millis() as u64,
297                // TODO: populate bytes_sent, bytes_received, flows_count when real-time stats tracking is added
298                ..Default::default()
299            };
300            interceptor_2.on_disconnect(&conn_info_2, &stats).await;
301
302            if let Err(err) = result {
303                error!("Error serving connection: {:?}", err);
304            }
305        }));
306    }
307
308    Ok(())
309}