Skip to main content

relay_core_lib/proxy/
server.rs

1use hyper::server::conn::http1;
2use hyper::service::service_fn;
3use hyper_rustls::HttpsConnectorBuilder;
4use hyper_util::client::legacy::Client;
5use hyper_util::rt::TokioIo;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::time::Duration;
9use tokio::sync::{mpsc::Sender, watch};
10use tracing::{error, info};
11
12use crate::capture::CaptureSource;
13use crate::capture::loop_detection::LoopDetector;
14use crate::interceptor::{ENGINE_INDEX, Interceptor};
15use crate::proxy::http::handle_request;
16use crate::proxy::http_utils::HttpsClient;
17use crate::tls::CertificateAuthority;
18use relay_core_api::flow::FlowUpdate;
19use relay_core_api::policy::ProxyPolicy;
20
21static CONN_COUNTER: AtomicUsize = AtomicUsize::new(0);
22
23/// Start the HTTP Proxy Server
24pub async fn start_proxy<S>(
25    mut source: S,
26    on_flow: Sender<FlowUpdate>,
27    interceptor: Arc<dyn Interceptor>,
28    ca: Arc<CertificateAuthority>,
29    policy: watch::Receiver<ProxyPolicy>,
30    client: Option<Arc<HttpsClient>>,
31    shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
32) -> crate::error::Result<()>
33where
34    S: CaptureSource + Send + 'static,
35    S::IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
36{
37    info!("RelayCore Proxy starting...");
38    info!("CA Loaded. Root cert:\n{}", ca.get_ca_cert_pem());
39    info!("Proxy Policy: {:?}", policy.borrow());
40
41    // Initialize HTTP Client (shared)
42    let client = if let Some(c) = client {
43        c
44    } else {
45        let https = HttpsConnectorBuilder::new()
46            .with_native_roots()?
47            .https_or_http()
48            .enable_http1()
49            .enable_http2()
50            .build();
51        let client: HttpsClient = Client::builder(hyper_util::rt::TokioExecutor::new())
52            .timer(hyper_util::rt::TokioTimer::new())
53            .pool_idle_timeout(Duration::from_secs(60))
54            .pool_max_idle_per_host(32)
55            .http2_initial_stream_window_size(2 * 1024 * 1024) // 2MB
56            .http2_initial_connection_window_size(4 * 1024 * 1024) // 4MB
57            .http2_keep_alive_interval(Duration::from_secs(20))
58            .http2_keep_alive_timeout(Duration::from_secs(10))
59            .build(https);
60        Arc::new(client)
61    };
62
63    // Initialize Loop Detector
64    let listen_addrs = source.listen_addrs().into_iter().collect();
65    let loop_detector = Arc::new(LoopDetector::new(listen_addrs));
66    {
67        let loop_detector_bg = loop_detector.clone();
68        tokio::spawn(async move {
69            // Prime local interface cache at startup.
70            loop_detector_bg.refresh_local_addrs().await;
71            let mut ticker = tokio::time::interval(Duration::from_secs(60));
72            loop {
73                ticker.tick().await;
74                loop_detector_bg.refresh_local_addrs().await;
75            }
76        });
77    }
78
79    let mut shutdown_rx = shutdown_rx;
80
81    loop {
82        // Accept connection from abstract source or handle shutdown
83        let connection_result = tokio::select! {
84            res = source.accept() => res,
85            _ = async {
86                if let Some(rx) = shutdown_rx.as_mut() {
87                    rx.await.ok();
88                } else {
89                    std::future::pending::<()>().await;
90                }
91            } => {
92                info!("RelayCore Proxy received shutdown signal");
93                break;
94            }
95        };
96
97        let connection = match connection_result {
98            Ok(conn) => conn,
99            Err(e) => {
100                error!("Error accepting connection: {}", e);
101                continue;
102            }
103        };
104
105        let stream = connection.stream;
106        let client_addr = connection.client_addr;
107        let target_addr = connection.target_addr;
108
109        let io = TokioIo::new(stream);
110        let on_flow = on_flow.clone();
111        let ca = ca.clone();
112        let client = client.clone();
113        let interceptor = interceptor.clone();
114        let policy = policy.clone();
115        let loop_detector = loop_detector.clone();
116
117        let engine_index = CONN_COUNTER.fetch_add(1, Ordering::Relaxed);
118
119        tokio::task::spawn(ENGINE_INDEX.scope(engine_index, async move {
120            if let Err(err) = http1::Builder::new()
121                .timer(hyper_util::rt::TokioTimer::new())
122                .header_read_timeout(Duration::from_secs(10))
123                .preserve_header_case(true)
124                .title_case_headers(true)
125                .serve_connection(
126                    io,
127                    service_fn(move |req| {
128                        handle_request(
129                            req,
130                            client_addr,
131                            on_flow.clone(),
132                            ca.clone(),
133                            client.clone(),
134                            interceptor.clone(),
135                            target_addr,
136                            policy.clone(),
137                            loop_detector.clone(),
138                        )
139                    }),
140                )
141                .with_upgrades()
142                .await
143            {
144                error!("Error serving connection: {:?}", err);
145            }
146        }));
147    }
148
149    Ok(())
150}