Skip to main content

relay_core_lib/proxy/
server.rs

1use hyper::server::conn::http1;
2use hyper::service::service_fn;
3use hyper_rustls::HttpsConnectorBuilder;
4use hyper_util::client::legacy::Client;
5use hyper_util::rt::TokioIo;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::time::Duration;
9use tokio::sync::{mpsc::Sender, watch};
10use tracing::{error, info};
11
12use crate::capture::CaptureSource;
13use crate::capture::loop_detection::LoopDetector;
14use crate::interceptor::{ENGINE_INDEX, Interceptor};
15use crate::proxy::circuit_breaker::CircuitBreaker;
16use crate::proxy::http::handle_request;
17use crate::proxy::http_utils::HttpsClient;
18use crate::tls::CertificateAuthority;
19use relay_core_api::flow::FlowUpdate;
20use relay_core_api::policy::ProxyPolicy;
21
/// Process-wide count of connections accepted by `start_proxy`.
///
/// Incremented once per accepted connection; the value read before the
/// increment is what the connection's task runs with as its `ENGINE_INDEX`.
static CONN_COUNTER: AtomicUsize = AtomicUsize::new(0);
23
24/// Start the HTTP Proxy Server
25pub async fn start_proxy<S>(
26    mut source: S,
27    on_flow: Sender<FlowUpdate>,
28    interceptor: Arc<dyn Interceptor>,
29    ca: Arc<CertificateAuthority>,
30    policy: watch::Receiver<ProxyPolicy>,
31    client: Option<Arc<HttpsClient>>,
32    shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
33) -> crate::error::Result<()>
34where
35    S: CaptureSource + Send + 'static,
36    S::IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
37{
38    info!("RelayCore Proxy starting...");
39    info!("CA Loaded. Root cert:\n{}", ca.get_ca_cert_pem());
40    info!("Proxy Policy: {:?}", policy.borrow());
41
42    // Initialize HTTP Client (shared)
43    let client = if let Some(c) = client {
44        c
45    } else {
46        let https = HttpsConnectorBuilder::new()
47            .with_native_roots()?
48            .https_or_http()
49            .enable_http1()
50            .enable_http2()
51            .build();
52        let client: HttpsClient = Client::builder(hyper_util::rt::TokioExecutor::new())
53            .timer(hyper_util::rt::TokioTimer::new())
54            .pool_idle_timeout(Duration::from_secs(60))
55            .pool_max_idle_per_host(32)
56            .http2_initial_stream_window_size(2 * 1024 * 1024) // 2MB
57            .http2_initial_connection_window_size(4 * 1024 * 1024) // 4MB
58            .http2_keep_alive_interval(Duration::from_secs(20))
59            .http2_keep_alive_timeout(Duration::from_secs(10))
60            .build(https);
61        Arc::new(client)
62    };
63
64    // Initialize Loop Detector
65    let listen_addrs = source.listen_addrs().into_iter().collect();
66    let loop_detector = Arc::new(LoopDetector::new(listen_addrs));
67    {
68        let loop_detector_bg = loop_detector.clone();
69        tokio::spawn(async move {
70            // Prime local interface cache at startup.
71            loop_detector_bg.refresh_local_addrs().await;
72            let mut ticker = tokio::time::interval(Duration::from_secs(60));
73            loop {
74                ticker.tick().await;
75                loop_detector_bg.refresh_local_addrs().await;
76            }
77        });
78    }
79
80    // Initialize Circuit Breaker (P3)
81    let circuit_breaker = Arc::new(CircuitBreaker::default());
82
83    let mut shutdown_rx = shutdown_rx;
84
85    loop {
86        // Accept connection from abstract source or handle shutdown
87        let connection_result = tokio::select! {
88            res = source.accept() => res,
89            _ = async {
90                if let Some(rx) = shutdown_rx.as_mut() {
91                    rx.await.ok();
92                } else {
93                    std::future::pending::<()>().await;
94                }
95            } => {
96                info!("RelayCore Proxy received shutdown signal");
97                break;
98            }
99        };
100
101        let connection = match connection_result {
102            Ok(conn) => conn,
103            Err(e) => {
104                error!("Error accepting connection: {}", e);
105                continue;
106            }
107        };
108
109        let stream = connection.stream;
110        let client_addr = connection.client_addr;
111        let target_addr = connection.target_addr;
112
113        let io = TokioIo::new(stream);
114        let on_flow = on_flow.clone();
115        let ca = ca.clone();
116        let client = client.clone();
117        let interceptor = interceptor.clone();
118        let policy = policy.clone();
119        let loop_detector = loop_detector.clone();
120
121        let circuit_breaker = circuit_breaker.clone();
122
123        let engine_index = CONN_COUNTER.fetch_add(1, Ordering::Relaxed);
124
125        tokio::task::spawn(ENGINE_INDEX.scope(engine_index, async move {
126            if let Err(err) = http1::Builder::new()
127                .timer(hyper_util::rt::TokioTimer::new())
128                .header_read_timeout(Duration::from_secs(10))
129                .preserve_header_case(true)
130                .title_case_headers(true)
131                .serve_connection(
132                    io,
133                    service_fn(move |req| {
134                        handle_request(
135                            req,
136                            client_addr,
137                            on_flow.clone(),
138                            ca.clone(),
139                            client.clone(),
140                            interceptor.clone(),
141                            target_addr,
142                            policy.clone(),
143                            loop_detector.clone(),
144                            circuit_breaker.clone(),
145                        )
146                    }),
147                )
148                .with_upgrades()
149                .await
150            {
151                error!("Error serving connection: {:?}", err);
152            }
153        }));
154    }
155
156    Ok(())
157}