plane_dynamic_proxy/server.rs

1use crate::{
2    body::SimpleBody, graceful_shutdown::GracefulShutdown, https_redirect::HttpsRedirectService,
3};
4use anyhow::Result;
5use http::HeaderValue;
6use hyper::{body::Incoming, service::Service, Request, Response};
7use hyper_util::{
8    rt::{TokioExecutor, TokioIo},
9    server::conn::auto::Builder as ServerBuilder,
10};
11use rustls::{server::ResolvesServerCert, ServerConfig};
12use std::{
13    net::{IpAddr, SocketAddr},
14    sync::Arc,
15    time::Duration,
16};
17use tokio::{net::TcpListener, select};
18use tokio_rustls::TlsAcceptor;
19
/// Request header through which the backend learns the connecting client's IP address.
const X_FORWARDED_FOR: &'static str = "x-forwarded-for";

/// Request header through which the backend learns the scheme (`http` or `https`)
/// the client connected with.
const X_FORWARDED_PROTO: &'static str = "x-forwarded-proto";
25
26/// A simple server that wraps a hyper service and handles requests.
27/// The server can be configured to listen for either HTTP and HTTPS,
28/// and supports graceful shutdown and x-forwarded-* headers.
29pub struct SimpleHttpServer {
30    handle: tokio::task::JoinHandle<()>,
31    graceful_shutdown: Option<GracefulShutdown>,
32}
33
34async fn listen_loop<S>(listener: TcpListener, service: S, graceful_shutdown: GracefulShutdown)
35where
36    S: Service<Request<Incoming>, Response = Response<SimpleBody>> + Clone + Send + 'static,
37    S::Future: Send + 'static,
38    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
39{
40    let mut recv = graceful_shutdown.subscribe();
41
42    loop {
43        let stream = select! {
44            stream = listener.accept() => stream,
45            _ = recv.changed() => break,
46        };
47
48        let (stream, remote_addr) = match stream {
49            Ok((stream, remote_addr)) => (stream, remote_addr),
50            Err(e) => {
51                tracing::warn!(?e, "Failed to accept connection.");
52                continue;
53            }
54        };
55        let remote_ip = remote_addr.ip();
56        let service = WrappedService::new(service.clone(), remote_ip, "http");
57
58        let server = ServerBuilder::new(TokioExecutor::new());
59        let io = TokioIo::new(stream);
60        let conn = server.serve_connection_with_upgrades(io, service);
61
62        let conn = graceful_shutdown.watch(conn.into_owned());
63        tokio::spawn(async {
64            if let Err(e) = conn.await {
65                tracing::warn!(?e, "Failed to serve connection.");
66            }
67        });
68    }
69}
70
71async fn listen_loop_tls<S>(
72    listener: TcpListener,
73    service: S,
74    resolver: Arc<dyn ResolvesServerCert>,
75    graceful_shutdown: GracefulShutdown,
76) where
77    S: Service<Request<Incoming>, Response = Response<SimpleBody>> + Clone + Send + 'static,
78    S::Future: Send + 'static,
79    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
80{
81    let server_config = ServerConfig::builder()
82        .with_no_client_auth()
83        .with_cert_resolver(resolver);
84    let tls_acceptor = TlsAcceptor::from(Arc::new(server_config));
85    let mut recv = graceful_shutdown.subscribe();
86
87    loop {
88        let stream = select! {
89            stream = listener.accept() => stream,
90            _ = recv.changed() => break,
91        };
92
93        let (stream, remote_addr) = match stream {
94            Ok((stream, remote_addr)) => (stream, remote_addr),
95            Err(e) => {
96                tracing::warn!(?e, "Failed to accept connection.");
97                continue;
98            }
99        };
100        let remote_ip = remote_addr.ip();
101        let service = WrappedService::new(service.clone(), remote_ip, "https");
102        let tls_acceptor = tls_acceptor.clone();
103
104        let graceful_shutdown = graceful_shutdown.clone();
105        tokio::spawn(async move {
106            let server = ServerBuilder::new(TokioExecutor::new());
107
108            let stream = match tls_acceptor.accept(stream).await {
109                Ok(stream) => stream,
110                Err(e) => {
111                    tracing::warn!(?e, "Failed to accept TLS connection.");
112                    return;
113                }
114            };
115            let io = TokioIo::new(stream);
116
117            let conn = server.serve_connection_with_upgrades(io, service);
118            let conn = graceful_shutdown.watch(conn.into_owned());
119
120            if let Err(e) = conn.await {
121                tracing::warn!(?e, "Failed to serve connection.");
122            }
123        });
124    }
125}
126
127pub enum HttpsConfig {
128    Http,
129    Https {
130        resolver: Arc<dyn ResolvesServerCert>,
131    },
132}
133
134impl HttpsConfig {
135    pub fn from_resolver<R: ResolvesServerCert + 'static>(resolver: R) -> Self {
136        Self::Https {
137            resolver: Arc::new(resolver),
138        }
139    }
140
141    pub fn http() -> Self {
142        Self::Http
143    }
144}
145
146impl SimpleHttpServer {
147    pub fn new<S>(service: S, listener: TcpListener, https_config: HttpsConfig) -> Result<Self>
148    where
149        S: Service<Request<Incoming>, Response = Response<SimpleBody>> + Clone + Send + 'static,
150        S::Future: Send + 'static,
151        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
152    {
153        let graceful_shutdown = GracefulShutdown::new();
154
155        let handle = match https_config {
156            HttpsConfig::Http => {
157                tokio::spawn(listen_loop(listener, service, graceful_shutdown.clone()))
158            }
159            HttpsConfig::Https { resolver } => {
160                if rustls::crypto::ring::default_provider()
161                    .install_default()
162                    .is_err()
163                {
164                    tracing::info!("Using already-installed crypto provider.")
165                }
166
167                tokio::spawn(listen_loop_tls(
168                    listener,
169                    service,
170                    resolver,
171                    graceful_shutdown.clone(),
172                ))
173            }
174        };
175
176        Ok(Self {
177            handle,
178            graceful_shutdown: Some(graceful_shutdown),
179        })
180    }
181
182    pub async fn graceful_shutdown(mut self) {
183        println!("Shutting down");
184        let graceful_shutdown = self
185            .graceful_shutdown
186            .take()
187            .expect("self.graceful_shutdown is always set");
188        graceful_shutdown.shutdown().await;
189    }
190
191    pub async fn graceful_shutdown_with_timeout(mut self, timeout: Duration) {
192        let graceful_shutdown = self
193            .graceful_shutdown
194            .take()
195            .expect("self.graceful_shutdown is always set");
196        let result = tokio::time::timeout(timeout, graceful_shutdown.shutdown()).await;
197
198        if let Err(e) = result {
199            tracing::warn!(?e, "Timed out waiting for graceful shutdown, aborting.");
200        }
201    }
202}
203
204impl Drop for SimpleHttpServer {
205    fn drop(&mut self) {
206        if self.graceful_shutdown.is_some() {
207            tracing::warn!("Shutting down SimpleHttpServer without a call to graceful_shutdown. Connections will be dropped abruptly!");
208        }
209
210        self.handle.abort();
211    }
212}
213
214pub struct ServerWithHttpRedirect {
215    http_server: SimpleHttpServer,
216    https_server: Option<SimpleHttpServer>,
217}
218
219pub struct ServerWithHttpRedirectHttpsConfig {
220    pub https_port: u16,
221    pub resolver: Arc<dyn ResolvesServerCert>,
222}
223
224pub struct ServerWithHttpRedirectConfig {
225    pub http_port: u16,
226    pub https_config: Option<ServerWithHttpRedirectHttpsConfig>,
227}
228
229impl ServerWithHttpRedirect {
230    pub async fn new<S>(service: S, server_config: ServerWithHttpRedirectConfig) -> Result<Self>
231    where
232        S: Service<Request<Incoming>, Response = Response<SimpleBody>>
233            + Clone
234            + Send
235            + Sync
236            + 'static,
237        S::Future: Send + 'static,
238        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
239    {
240        if let Some(https_config) = server_config.https_config {
241            // Serve HTTPS
242            let https_listener =
243                TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], https_config.https_port)))
244                    .await?;
245            let https_server = SimpleHttpServer::new(
246                service,
247                https_listener,
248                HttpsConfig::Https {
249                    resolver: https_config.resolver,
250                },
251            )?;
252
253            // Redirect HTTP to HTTPS
254            let http_listener =
255                TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], server_config.http_port)))
256                    .await?;
257            let http_server =
258                SimpleHttpServer::new(HttpsRedirectService, http_listener, HttpsConfig::Http)?;
259
260            Ok(Self {
261                http_server,
262                https_server: Some(https_server),
263            })
264        } else {
265            let listener =
266                TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], server_config.http_port)))
267                    .await?;
268            let http_server = SimpleHttpServer::new(service, listener, HttpsConfig::Http)?;
269
270            Ok(Self {
271                http_server,
272                https_server: None,
273            })
274        }
275    }
276
277    pub async fn graceful_shutdown_with_timeout(self, timeout: Duration) {
278        if let Some(https_server) = self.https_server {
279            tokio::join!(
280                self.http_server.graceful_shutdown_with_timeout(timeout),
281                https_server.graceful_shutdown_with_timeout(timeout)
282            );
283        } else {
284            self.http_server
285                .graceful_shutdown_with_timeout(timeout)
286                .await;
287        }
288    }
289}
290
/// Decorates an inner service so that every request it receives carries
/// `X-Forwarded-For` (client IP) and `X-Forwarded-Proto` (client scheme) headers.
struct WrappedService<S> {
    /// The service that actually handles the request.
    inner: S,
    /// Client IP address written into `X-Forwarded-For`.
    forwarded_for: IpAddr,
    /// Scheme written into `X-Forwarded-Proto` (this file passes `"http"` or `"https"`).
    forwarded_proto: &'static str,
}
298
299impl<S> WrappedService<S> {
300    pub fn new(inner: S, forwarded_for: IpAddr, forwarded_proto: &'static str) -> Self {
301        Self {
302            inner,
303            forwarded_for,
304            forwarded_proto,
305        }
306    }
307}
308
309impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for WrappedService<S>
310where
311    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
312{
313    type Response = S::Response;
314    type Error = S::Error;
315    type Future = S::Future;
316
317    fn call(&self, request: Request<ReqBody>) -> Self::Future {
318        let mut request = request;
319        request.headers_mut().insert(
320            X_FORWARDED_FOR,
321            HeaderValue::from_str(&format!("{}", self.forwarded_for))
322                .expect("X-Forwarded-For is always valid"),
323        );
324        request.headers_mut().insert(
325            X_FORWARDED_PROTO,
326            HeaderValue::from_str(self.forwarded_proto).expect("X-Forwarded-Proto is always valid"),
327        );
328        self.inner.call(request)
329    }
330}