Skip to main content

relay_core_lib/proxy/
server.rs

1use hyper::server::conn::http1;
2use hyper::service::service_fn;
3use hyper_rustls::HttpsConnectorBuilder;
4use hyper_util::client::legacy::Client;
5use hyper_util::rt::TokioIo;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::time::{Duration, Instant};
9use tokio::sync::{mpsc::Sender, watch};
10use tracing::{error, info};
11use uuid::Uuid;
12
13use crate::capture::CaptureSource;
14use crate::capture::loop_detection::LoopDetector;
15use crate::interceptor::{ConnectAction, ConnectionInfo, ENGINE_INDEX, Interceptor};
16use crate::proxy::circuit_breaker::CircuitBreaker;
17use crate::proxy::http::handle_request;
18use crate::proxy::http_utils::HttpsClient;
19use crate::rule::engine::executor::{ConnectOverride, RuleEngine};
20use crate::tls::CertificateAuthority;
21use chrono::Utc;
22use relay_core_api::flow::{Flow, FlowUpdate, Layer, NetworkInfo, TcpLayer, TransportProtocol};
23use relay_core_api::policy::ProxyPolicy;
24use relay_core_api::rule::RuleStage;
25use std::collections::HashMap;
26use std::net::{IpAddr, SocketAddr};
27
28static CONN_COUNTER: AtomicUsize = AtomicUsize::new(0);
29
30/// Start the HTTP Proxy Server
31#[allow(clippy::too_many_arguments)]
32pub async fn start_proxy<S>(
33    mut source: S,
34    on_flow: Sender<FlowUpdate>,
35    interceptor: Arc<dyn Interceptor>,
36    ca: Arc<CertificateAuthority>,
37    policy: watch::Receiver<ProxyPolicy>,
38    client: Option<Arc<HttpsClient>>,
39    shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
40    rule_engine: Option<Arc<RuleEngine>>,
41) -> crate::error::Result<()>
42where
43    S: CaptureSource + Send + 'static,
44    S::IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
45{
46    info!("RelayCore Proxy starting...");
47    info!("CA Loaded. Root cert:\n{}", ca.get_ca_cert_pem());
48    info!("Proxy Policy: {:?}", policy.borrow());
49
50    // Initialize HTTP Client (shared)
51    let client = if let Some(c) = client {
52        c
53    } else {
54        let https = HttpsConnectorBuilder::new()
55            .with_native_roots()?
56            .https_or_http()
57            .enable_http1()
58            .enable_http2()
59            .build();
60        let client: HttpsClient = Client::builder(hyper_util::rt::TokioExecutor::new())
61            .timer(hyper_util::rt::TokioTimer::new())
62            .pool_idle_timeout(Duration::from_secs(60))
63            .pool_max_idle_per_host(32)
64            .http2_initial_stream_window_size(2 * 1024 * 1024) // 2MB
65            .http2_initial_connection_window_size(4 * 1024 * 1024) // 4MB
66            .http2_keep_alive_interval(Duration::from_secs(20))
67            .http2_keep_alive_timeout(Duration::from_secs(10))
68            .build(https);
69        Arc::new(client)
70    };
71
72    // Initialize Loop Detector
73    let listen_addrs = source.listen_addrs().into_iter().collect();
74    let loop_detector = Arc::new(LoopDetector::new(listen_addrs));
75    {
76        let loop_detector_bg = loop_detector.clone();
77        tokio::spawn(async move {
78            // Prime local interface cache at startup.
79            loop_detector_bg.refresh_local_addrs().await;
80            let mut ticker = tokio::time::interval(Duration::from_secs(60));
81            loop {
82                ticker.tick().await;
83                loop_detector_bg.refresh_local_addrs().await;
84            }
85        });
86    }
87
88    // Initialize Circuit Breaker (P3)
89    let circuit_breaker = Arc::new(CircuitBreaker::default());
90
91    let mut shutdown_rx = shutdown_rx;
92
93    loop {
94        // Accept connection from abstract source or handle shutdown
95        let connection_result = tokio::select! {
96            res = source.accept() => res,
97            _ = async {
98                if let Some(rx) = shutdown_rx.as_mut() {
99                    rx.await.ok();
100                } else {
101                    std::future::pending::<()>().await;
102                }
103            } => {
104                info!("RelayCore Proxy received shutdown signal");
105                break;
106            }
107        };
108
109        let connection = match connection_result {
110            Ok(conn) => conn,
111            Err(e) => {
112                error!("Error accepting connection: {}", e);
113                continue;
114            }
115        };
116
117        let stream = connection.stream;
118        let client_addr = connection.client_addr;
119        let target_addr = connection.target_addr;
120
121        let conn_id = Uuid::new_v4();
122        let conn_info = ConnectionInfo {
123            id: conn_id,
124            client_addr,
125            server_addr: target_addr,
126            // TODO: extract SNI from TLS ClientHello (requires pre-rustls intercept)
127            tls_sni: None,
128        };
129
130        match interceptor.on_connect(&conn_info).await {
131            ConnectAction::Drop { reason } => {
132                info!("Connection {} dropped by onConnect: {}", conn_id, reason);
133                interceptor
134                    .on_disconnect(&conn_info, &Default::default())
135                    .await;
136                continue;
137            }
138            ConnectAction::Allow => {}
139        }
140
141        let mut connect_target = target_addr;
142        if let Some(ref engine) = rule_engine
143            && engine.has_rules_for_stage(RuleStage::Connect)
144        {
145            let mut flow = Flow {
146                id: conn_id,
147                start_time: Utc::now(),
148                end_time: None,
149                network: NetworkInfo {
150                    client_ip: client_addr.ip().to_string(),
151                    client_port: client_addr.port(),
152                    server_ip: target_addr.map(|a| a.ip().to_string()).unwrap_or_default(),
153                    server_port: target_addr.map(|a| a.port()).unwrap_or(0),
154                    protocol: TransportProtocol::TCP,
155                    tls: false,
156                    tls_version: None,
157                    sni: None,
158                },
159                layer: Layer::Tcp(TcpLayer {
160                    bytes_up: 0,
161                    bytes_down: 0,
162                }),
163                tags: vec![],
164                meta: HashMap::new(),
165                resilience_trace: None,
166                rule_variables: HashMap::new(),
167                matched_rules: vec![],
168            };
169            let ctx = engine.execute(RuleStage::Connect, &mut flow).await;
170            if ctx.is_terminated() {
171                info!("Connection {} terminated by Connect stage rule", conn_id);
172                interceptor
173                    .on_disconnect(&conn_info, &Default::default())
174                    .await;
175                continue;
176            }
177            if let Some(conn_override) = &ctx.connect_override {
178                match conn_override {
179                    ConnectOverride::ForwardPort { host: _, port } => {
180                        connect_target = Some(SocketAddr::new(
181                            target_addr
182                                .map(|a| a.ip())
183                                .unwrap_or(IpAddr::from([0, 0, 0, 0])),
184                            *port,
185                        ));
186                        tracing::debug!("Connect stage ForwardPort -> port {}", port);
187                    }
188                    ConnectOverride::RedirectIp { ip } => {
189                        let port = connect_target.map(|a| a.port()).unwrap_or(0);
190                        connect_target = Some(SocketAddr::new(*ip, port));
191                        tracing::debug!("Connect stage RedirectIp -> {}", ip);
192                    }
193                    ConnectOverride::SetTtl { ttl } => {
194                        // Socket-level TTL application requires access to the raw
195                        // TcpStream before it enters the hyper HTTP stack, which is
196                        // not feasible with the current connection model.
197                        // Deferred to 1.x.
198                        tracing::warn!(
199                            "Connect stage SetTtl({}) is not yet implemented; TTL unchanged",
200                            ttl
201                        );
202                    }
203                }
204            }
205        }
206
207        let target_addr = connect_target;
208
209        let io = TokioIo::new(stream);
210        let on_flow = on_flow.clone();
211        let ca = ca.clone();
212        let client = client.clone();
213        let interceptor = interceptor.clone();
214        let policy = policy.clone();
215        let loop_detector = loop_detector.clone();
216
217        let circuit_breaker = circuit_breaker.clone();
218
219        let engine_index = CONN_COUNTER.fetch_add(1, Ordering::Relaxed);
220
221        let conn_info_2 = conn_info.clone();
222        let interceptor_2 = interceptor.clone();
223        tokio::task::spawn(ENGINE_INDEX.scope(engine_index, async move {
224            let conn_start = Instant::now();
225            let result = http1::Builder::new()
226                .timer(hyper_util::rt::TokioTimer::new())
227                .header_read_timeout(Duration::from_secs(10))
228                .preserve_header_case(true)
229                .title_case_headers(true)
230                .serve_connection(
231                    io,
232                    service_fn(move |req| {
233                        handle_request(
234                            req,
235                            client_addr,
236                            on_flow.clone(),
237                            ca.clone(),
238                            client.clone(),
239                            interceptor.clone(),
240                            target_addr,
241                            policy.clone(),
242                            loop_detector.clone(),
243                            circuit_breaker.clone(),
244                        )
245                    }),
246                )
247                .with_upgrades()
248                .await;
249
250            let stats = crate::interceptor::ConnectionStats {
251                duration_ms: conn_start.elapsed().as_millis() as u64,
252                // TODO: populate bytes_sent, bytes_received, flows_count when real-time stats tracking is added
253                ..Default::default()
254            };
255            interceptor_2.on_disconnect(&conn_info_2, &stats).await;
256
257            if let Err(err) = result {
258                error!("Error serving connection: {:?}", err);
259            }
260        }));
261    }
262
263    Ok(())
264}