// relay_core_lib/proxy/server.rs

1use std::sync::Arc;
2use std::time::Duration;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use tokio::sync::{mpsc::Sender, watch};
5use hyper_util::rt::TokioIo;
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper_util::client::legacy::Client;
9use hyper_rustls::HttpsConnectorBuilder;
10use tracing::{info, error};
11
12use crate::capture::CaptureSource; 
13use crate::tls::CertificateAuthority;
14use crate::interceptor::{Interceptor, ENGINE_INDEX};
15use relay_core_api::flow::FlowUpdate;
16use relay_core_api::policy::ProxyPolicy;
17use crate::proxy::http::handle_request;
18use crate::proxy::http_utils::HttpsClient;
19use crate::capture::loop_detection::LoopDetector;
20
/// Counter incremented once per accepted connection; the previous value becomes
/// that connection's `ENGINE_INDEX`.
static CONN_COUNTER: AtomicUsize = AtomicUsize::new(0);
22
23/// Start the HTTP Proxy Server
24pub async fn start_proxy<S>(
25    mut source: S,
26    on_flow: Sender<FlowUpdate>,
27    interceptor: Arc<dyn Interceptor>,
28    ca: Arc<CertificateAuthority>,
29    policy: watch::Receiver<ProxyPolicy>,
30    client: Option<Arc<HttpsClient>>,
31    shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
32) -> crate::error::Result<()> 
33where
34    S: CaptureSource + Send + 'static,
35    S::IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
36{
37    info!("RelayCore Proxy starting...");
38    info!("CA Loaded. Root cert:\n{}", ca.get_ca_cert_pem());
39    info!("Proxy Policy: {:?}", policy.borrow());
40
41    // Initialize HTTP Client (shared)
42    let client = if let Some(c) = client {
43        c
44    } else {
45        let https = HttpsConnectorBuilder::new()
46            .with_native_roots()?
47            .https_or_http()
48            .enable_http1()
49            .enable_http2()
50            .build();
51        let client: HttpsClient = Client::builder(hyper_util::rt::TokioExecutor::new())
52            .timer(hyper_util::rt::TokioTimer::new())
53            .pool_idle_timeout(Duration::from_secs(60))
54            .pool_max_idle_per_host(32)
55            .http2_initial_stream_window_size(2 * 1024 * 1024) // 2MB
56            .http2_initial_connection_window_size(4 * 1024 * 1024) // 4MB
57            .http2_keep_alive_interval(Duration::from_secs(20))
58            .http2_keep_alive_timeout(Duration::from_secs(10))
59            .build(https);
60        Arc::new(client)
61    };
62
63    // Initialize Loop Detector
64    let listen_addrs = source.listen_addrs().into_iter().collect();
65    let loop_detector = Arc::new(LoopDetector::new(listen_addrs));
66    {
67        let loop_detector_bg = loop_detector.clone();
68        tokio::spawn(async move {
69            // Prime local interface cache at startup.
70            loop_detector_bg.refresh_local_addrs().await;
71            let mut ticker = tokio::time::interval(Duration::from_secs(60));
72            loop {
73                ticker.tick().await;
74                loop_detector_bg.refresh_local_addrs().await;
75            }
76        });
77    }
78
79    let mut shutdown_rx = shutdown_rx;
80
81    loop {
82        // Accept connection from abstract source or handle shutdown
83        let connection_result = tokio::select! {
84            res = source.accept() => res,
85            _ = async {
86                if let Some(rx) = shutdown_rx.as_mut() {
87                    rx.await.ok();
88                } else {
89                    std::future::pending::<()>().await;
90                }
91            } => {
92                info!("RelayCore Proxy received shutdown signal");
93                break;
94            }
95        };
96
97        let connection = match connection_result {
98            Ok(conn) => conn,
99            Err(e) => {
100                error!("Error accepting connection: {}", e);
101                continue;
102            }
103        };
104
105        let stream = connection.stream;
106        let client_addr = connection.client_addr;
107        let target_addr = connection.target_addr;
108
109        let io = TokioIo::new(stream);
110        let on_flow = on_flow.clone();
111        let ca = ca.clone();
112        let client = client.clone();
113        let interceptor = interceptor.clone();
114        let policy = policy.clone();
115        let loop_detector = loop_detector.clone();
116
117        let engine_index = CONN_COUNTER.fetch_add(1, Ordering::Relaxed);
118
119        tokio::task::spawn(ENGINE_INDEX.scope(engine_index, async move {
120            if let Err(err) = http1::Builder::new()
121                .timer(hyper_util::rt::TokioTimer::new())
122                .header_read_timeout(Duration::from_secs(10))
123                .preserve_header_case(true)
124                .title_case_headers(true)
125                .serve_connection(io, service_fn(move |req| handle_request(
126                    req, 
127                    client_addr, 
128                    on_flow.clone(), 
129                    ca.clone(), 
130                    client.clone(), 
131                    interceptor.clone(),
132                    target_addr,
133                    policy.clone(),
134                    loop_detector.clone()
135                )))
136                .with_upgrades()
137                .await
138            {
139                error!("Error serving connection: {:?}", err);
140            }
141        }));
142    }
143
144    Ok(())
145}