1use crate::{
2 body::SimpleBody, graceful_shutdown::GracefulShutdown, https_redirect::HttpsRedirectService,
3};
4use anyhow::Result;
5use http::HeaderValue;
6use hyper::{body::Incoming, service::Service, Request, Response};
7use hyper_util::{
8 rt::{TokioExecutor, TokioIo},
9 server::conn::auto::Builder as ServerBuilder,
10};
11use rustls::{server::ResolvesServerCert, ServerConfig};
12use std::{
13 net::{IpAddr, SocketAddr},
14 sync::Arc,
15 time::Duration,
16};
17use tokio::{net::TcpListener, select};
18use tokio_rustls::TlsAcceptor;
19
/// Header set on every proxied request to the peer IP address of the client connection.
const X_FORWARDED_FOR: &str = "x-forwarded-for";

/// Header set on every proxied request to the scheme the client connected with
/// (`"http"` or `"https"`).
const X_FORWARDED_PROTO: &str = "x-forwarded-proto";
25
26pub struct SimpleHttpServer {
30 handle: tokio::task::JoinHandle<()>,
31 graceful_shutdown: Option<GracefulShutdown>,
32}
33
34async fn listen_loop<S>(listener: TcpListener, service: S, graceful_shutdown: GracefulShutdown)
35where
36 S: Service<Request<Incoming>, Response = Response<SimpleBody>> + Clone + Send + 'static,
37 S::Future: Send + 'static,
38 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
39{
40 let mut recv = graceful_shutdown.subscribe();
41
42 loop {
43 let stream = select! {
44 stream = listener.accept() => stream,
45 _ = recv.changed() => break,
46 };
47
48 let (stream, remote_addr) = match stream {
49 Ok((stream, remote_addr)) => (stream, remote_addr),
50 Err(e) => {
51 tracing::warn!(?e, "Failed to accept connection.");
52 continue;
53 }
54 };
55 let remote_ip = remote_addr.ip();
56 let service = WrappedService::new(service.clone(), remote_ip, "http");
57
58 let server = ServerBuilder::new(TokioExecutor::new());
59 let io = TokioIo::new(stream);
60 let conn = server.serve_connection_with_upgrades(io, service);
61
62 let conn = graceful_shutdown.watch(conn.into_owned());
63 tokio::spawn(async {
64 if let Err(e) = conn.await {
65 tracing::warn!(?e, "Failed to serve connection.");
66 }
67 });
68 }
69}
70
71async fn listen_loop_tls<S>(
72 listener: TcpListener,
73 service: S,
74 resolver: Arc<dyn ResolvesServerCert>,
75 graceful_shutdown: GracefulShutdown,
76) where
77 S: Service<Request<Incoming>, Response = Response<SimpleBody>> + Clone + Send + 'static,
78 S::Future: Send + 'static,
79 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
80{
81 let server_config = ServerConfig::builder()
82 .with_no_client_auth()
83 .with_cert_resolver(resolver);
84 let tls_acceptor = TlsAcceptor::from(Arc::new(server_config));
85 let mut recv = graceful_shutdown.subscribe();
86
87 loop {
88 let stream = select! {
89 stream = listener.accept() => stream,
90 _ = recv.changed() => break,
91 };
92
93 let (stream, remote_addr) = match stream {
94 Ok((stream, remote_addr)) => (stream, remote_addr),
95 Err(e) => {
96 tracing::warn!(?e, "Failed to accept connection.");
97 continue;
98 }
99 };
100 let remote_ip = remote_addr.ip();
101 let service = WrappedService::new(service.clone(), remote_ip, "https");
102 let tls_acceptor = tls_acceptor.clone();
103
104 let graceful_shutdown = graceful_shutdown.clone();
105 tokio::spawn(async move {
106 let server = ServerBuilder::new(TokioExecutor::new());
107
108 let stream = match tls_acceptor.accept(stream).await {
109 Ok(stream) => stream,
110 Err(e) => {
111 tracing::warn!(?e, "Failed to accept TLS connection.");
112 return;
113 }
114 };
115 let io = TokioIo::new(stream);
116
117 let conn = server.serve_connection_with_upgrades(io, service);
118 let conn = graceful_shutdown.watch(conn.into_owned());
119
120 if let Err(e) = conn.await {
121 tracing::warn!(?e, "Failed to serve connection.");
122 }
123 });
124 }
125}
126
127pub enum HttpsConfig {
128 Http,
129 Https {
130 resolver: Arc<dyn ResolvesServerCert>,
131 },
132}
133
134impl HttpsConfig {
135 pub fn from_resolver<R: ResolvesServerCert + 'static>(resolver: R) -> Self {
136 Self::Https {
137 resolver: Arc::new(resolver),
138 }
139 }
140
141 pub fn http() -> Self {
142 Self::Http
143 }
144}
145
146impl SimpleHttpServer {
147 pub fn new<S>(service: S, listener: TcpListener, https_config: HttpsConfig) -> Result<Self>
148 where
149 S: Service<Request<Incoming>, Response = Response<SimpleBody>> + Clone + Send + 'static,
150 S::Future: Send + 'static,
151 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
152 {
153 let graceful_shutdown = GracefulShutdown::new();
154
155 let handle = match https_config {
156 HttpsConfig::Http => {
157 tokio::spawn(listen_loop(listener, service, graceful_shutdown.clone()))
158 }
159 HttpsConfig::Https { resolver } => {
160 if rustls::crypto::ring::default_provider()
161 .install_default()
162 .is_err()
163 {
164 tracing::info!("Using already-installed crypto provider.")
165 }
166
167 tokio::spawn(listen_loop_tls(
168 listener,
169 service,
170 resolver,
171 graceful_shutdown.clone(),
172 ))
173 }
174 };
175
176 Ok(Self {
177 handle,
178 graceful_shutdown: Some(graceful_shutdown),
179 })
180 }
181
182 pub async fn graceful_shutdown(mut self) {
183 println!("Shutting down");
184 let graceful_shutdown = self
185 .graceful_shutdown
186 .take()
187 .expect("self.graceful_shutdown is always set");
188 graceful_shutdown.shutdown().await;
189 }
190
191 pub async fn graceful_shutdown_with_timeout(mut self, timeout: Duration) {
192 let graceful_shutdown = self
193 .graceful_shutdown
194 .take()
195 .expect("self.graceful_shutdown is always set");
196 let result = tokio::time::timeout(timeout, graceful_shutdown.shutdown()).await;
197
198 if let Err(e) = result {
199 tracing::warn!(?e, "Timed out waiting for graceful shutdown, aborting.");
200 }
201 }
202}
203
204impl Drop for SimpleHttpServer {
205 fn drop(&mut self) {
206 if self.graceful_shutdown.is_some() {
207 tracing::warn!("Shutting down SimpleHttpServer without a call to graceful_shutdown. Connections will be dropped abruptly!");
208 }
209
210 self.handle.abort();
211 }
212}
213
214pub struct ServerWithHttpRedirect {
215 http_server: SimpleHttpServer,
216 https_server: Option<SimpleHttpServer>,
217}
218
219pub struct ServerWithHttpRedirectHttpsConfig {
220 pub https_port: u16,
221 pub resolver: Arc<dyn ResolvesServerCert>,
222}
223
224pub struct ServerWithHttpRedirectConfig {
225 pub http_port: u16,
226 pub https_config: Option<ServerWithHttpRedirectHttpsConfig>,
227}
228
229impl ServerWithHttpRedirect {
230 pub async fn new<S>(service: S, server_config: ServerWithHttpRedirectConfig) -> Result<Self>
231 where
232 S: Service<Request<Incoming>, Response = Response<SimpleBody>>
233 + Clone
234 + Send
235 + Sync
236 + 'static,
237 S::Future: Send + 'static,
238 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
239 {
240 if let Some(https_config) = server_config.https_config {
241 let https_listener =
243 TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], https_config.https_port)))
244 .await?;
245 let https_server = SimpleHttpServer::new(
246 service,
247 https_listener,
248 HttpsConfig::Https {
249 resolver: https_config.resolver,
250 },
251 )?;
252
253 let http_listener =
255 TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], server_config.http_port)))
256 .await?;
257 let http_server =
258 SimpleHttpServer::new(HttpsRedirectService, http_listener, HttpsConfig::Http)?;
259
260 Ok(Self {
261 http_server,
262 https_server: Some(https_server),
263 })
264 } else {
265 let listener =
266 TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], server_config.http_port)))
267 .await?;
268 let http_server = SimpleHttpServer::new(service, listener, HttpsConfig::Http)?;
269
270 Ok(Self {
271 http_server,
272 https_server: None,
273 })
274 }
275 }
276
277 pub async fn graceful_shutdown_with_timeout(self, timeout: Duration) {
278 if let Some(https_server) = self.https_server {
279 tokio::join!(
280 self.http_server.graceful_shutdown_with_timeout(timeout),
281 https_server.graceful_shutdown_with_timeout(timeout)
282 );
283 } else {
284 self.http_server
285 .graceful_shutdown_with_timeout(timeout)
286 .await;
287 }
288 }
289}
290
291struct WrappedService<S> {
294 inner: S,
295 forwarded_for: IpAddr,
296 forwarded_proto: &'static str,
297}
298
299impl<S> WrappedService<S> {
300 pub fn new(inner: S, forwarded_for: IpAddr, forwarded_proto: &'static str) -> Self {
301 Self {
302 inner,
303 forwarded_for,
304 forwarded_proto,
305 }
306 }
307}
308
309impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for WrappedService<S>
310where
311 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
312{
313 type Response = S::Response;
314 type Error = S::Error;
315 type Future = S::Future;
316
317 fn call(&self, request: Request<ReqBody>) -> Self::Future {
318 let mut request = request;
319 request.headers_mut().insert(
320 X_FORWARDED_FOR,
321 HeaderValue::from_str(&format!("{}", self.forwarded_for))
322 .expect("X-Forwarded-For is always valid"),
323 );
324 request.headers_mut().insert(
325 X_FORWARDED_PROTO,
326 HeaderValue::from_str(self.forwarded_proto).expect("X-Forwarded-Proto is always valid"),
327 );
328 self.inner.call(request)
329 }
330}