relay_core_lib/proxy/
server.rs1use hyper::server::conn::http1;
2use hyper::service::service_fn;
3use hyper_rustls::HttpsConnectorBuilder;
4use hyper_util::client::legacy::Client;
5use hyper_util::rt::TokioIo;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::time::Duration;
9use tokio::sync::{mpsc::Sender, watch};
10use tracing::{error, info};
11
12use crate::capture::CaptureSource;
13use crate::capture::loop_detection::LoopDetector;
14use crate::interceptor::{ENGINE_INDEX, Interceptor};
15use crate::proxy::circuit_breaker::CircuitBreaker;
16use crate::proxy::http::handle_request;
17use crate::proxy::http_utils::HttpsClient;
18use crate::tls::CertificateAuthority;
19use relay_core_api::flow::FlowUpdate;
20use relay_core_api::policy::ProxyPolicy;
21
/// Process-wide count of connections accepted by [`start_proxy`].
///
/// Each accepted connection takes the pre-increment value (via `fetch_add`)
/// and runs its task with that value bound to the `ENGINE_INDEX` task-local.
/// Only uniqueness of the returned values matters, so `Ordering::Relaxed` is
/// sufficient at the call site. The counter wraps on overflow.
static CONN_COUNTER: AtomicUsize = AtomicUsize::new(0);
23
24pub async fn start_proxy<S>(
26 mut source: S,
27 on_flow: Sender<FlowUpdate>,
28 interceptor: Arc<dyn Interceptor>,
29 ca: Arc<CertificateAuthority>,
30 policy: watch::Receiver<ProxyPolicy>,
31 client: Option<Arc<HttpsClient>>,
32 shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
33) -> crate::error::Result<()>
34where
35 S: CaptureSource + Send + 'static,
36 S::IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
37{
38 info!("RelayCore Proxy starting...");
39 info!("CA Loaded. Root cert:\n{}", ca.get_ca_cert_pem());
40 info!("Proxy Policy: {:?}", policy.borrow());
41
42 let client = if let Some(c) = client {
44 c
45 } else {
46 let https = HttpsConnectorBuilder::new()
47 .with_native_roots()?
48 .https_or_http()
49 .enable_http1()
50 .enable_http2()
51 .build();
52 let client: HttpsClient = Client::builder(hyper_util::rt::TokioExecutor::new())
53 .timer(hyper_util::rt::TokioTimer::new())
54 .pool_idle_timeout(Duration::from_secs(60))
55 .pool_max_idle_per_host(32)
56 .http2_initial_stream_window_size(2 * 1024 * 1024) .http2_initial_connection_window_size(4 * 1024 * 1024) .http2_keep_alive_interval(Duration::from_secs(20))
59 .http2_keep_alive_timeout(Duration::from_secs(10))
60 .build(https);
61 Arc::new(client)
62 };
63
64 let listen_addrs = source.listen_addrs().into_iter().collect();
66 let loop_detector = Arc::new(LoopDetector::new(listen_addrs));
67 {
68 let loop_detector_bg = loop_detector.clone();
69 tokio::spawn(async move {
70 loop_detector_bg.refresh_local_addrs().await;
72 let mut ticker = tokio::time::interval(Duration::from_secs(60));
73 loop {
74 ticker.tick().await;
75 loop_detector_bg.refresh_local_addrs().await;
76 }
77 });
78 }
79
80 let circuit_breaker = Arc::new(CircuitBreaker::default());
82
83 let mut shutdown_rx = shutdown_rx;
84
85 loop {
86 let connection_result = tokio::select! {
88 res = source.accept() => res,
89 _ = async {
90 if let Some(rx) = shutdown_rx.as_mut() {
91 rx.await.ok();
92 } else {
93 std::future::pending::<()>().await;
94 }
95 } => {
96 info!("RelayCore Proxy received shutdown signal");
97 break;
98 }
99 };
100
101 let connection = match connection_result {
102 Ok(conn) => conn,
103 Err(e) => {
104 error!("Error accepting connection: {}", e);
105 continue;
106 }
107 };
108
109 let stream = connection.stream;
110 let client_addr = connection.client_addr;
111 let target_addr = connection.target_addr;
112
113 let io = TokioIo::new(stream);
114 let on_flow = on_flow.clone();
115 let ca = ca.clone();
116 let client = client.clone();
117 let interceptor = interceptor.clone();
118 let policy = policy.clone();
119 let loop_detector = loop_detector.clone();
120
121 let circuit_breaker = circuit_breaker.clone();
122
123 let engine_index = CONN_COUNTER.fetch_add(1, Ordering::Relaxed);
124
125 tokio::task::spawn(ENGINE_INDEX.scope(engine_index, async move {
126 if let Err(err) = http1::Builder::new()
127 .timer(hyper_util::rt::TokioTimer::new())
128 .header_read_timeout(Duration::from_secs(10))
129 .preserve_header_case(true)
130 .title_case_headers(true)
131 .serve_connection(
132 io,
133 service_fn(move |req| {
134 handle_request(
135 req,
136 client_addr,
137 on_flow.clone(),
138 ca.clone(),
139 client.clone(),
140 interceptor.clone(),
141 target_addr,
142 policy.clone(),
143 loop_detector.clone(),
144 circuit_breaker.clone(),
145 )
146 }),
147 )
148 .with_upgrades()
149 .await
150 {
151 error!("Error serving connection: {:?}", err);
152 }
153 }));
154 }
155
156 Ok(())
157}