1use super::api_description::ApiDescription;
5use super::body::Body;
6use super::compression::add_vary_header;
7use super::compression::apply_gzip_compression;
8use super::compression::is_compressible_content_type;
9use super::compression::should_compress_response;
10use super::config::{CompressionConfig, ConfigDropshot, ConfigTls};
11#[cfg(feature = "usdt-probes")]
12use super::dtrace::probes;
13use super::handler::HandlerError;
14use super::handler::RequestContext;
15use super::http_util::HEADER_REQUEST_ID;
16use super::router::HttpRouter;
17use super::versioning::VersionPolicy;
18use super::ProbeRegistration;
19
20use async_stream::stream;
21use debug_ignore::DebugIgnore;
22use futures::future::{
23 BoxFuture, FusedFuture, FutureExt, Shared, TryFutureExt,
24};
25use futures::lock::Mutex;
26use futures::stream::{Stream, StreamExt};
27use hyper::service::Service;
28use hyper::Request;
29use hyper::Response;
30use rustls;
31use scopeguard::{guard, ScopeGuard};
32use std::convert::TryFrom;
33use std::future::Future;
34use std::mem;
35use std::net::SocketAddr;
36use std::num::NonZeroU32;
37use std::panic;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41use tokio::io::ReadBuf;
42use tokio::net::{TcpListener, TcpStream};
43use tokio::sync::oneshot;
44use tokio_rustls::{server::TlsStream, TlsAcceptor};
45use uuid::Uuid;
46use waitgroup::WaitGroup;
47
48use crate::config::HandlerTaskMode;
49use crate::RequestInfo;
50use slog::Logger;
51use thiserror::Error;
52
/// Boxed, thread-safe error type used where this module returns an opaque
/// error (for example from `HttpServerStarter::new` and as the `Service`
/// error type of `ServerRequestHandler`).
type GenericError = Box<dyn std::error::Error + Send + Sync>;
55
/// Marker trait for the caller-supplied context value stored in
/// [`DropshotState::private`]. It has no methods; it only bundles the
/// `Send + Sync + 'static` bounds the server needs.
pub trait ServerContext: Send + Sync + 'static {}

/// Blanket implementation: every `Send + Sync + 'static` type is a
/// `ServerContext`, so callers never implement the trait by hand.
impl<T: 'static> ServerContext for T where T: Send + Sync {}
62
/// State shared (behind an `Arc`) by the server, every connection handler and
/// every request handler.
#[derive(Debug)]
pub struct DropshotState<C: ServerContext> {
    /// Caller-provided context value; exposed through
    /// `HttpServer::app_private`.
    pub private: C,
    /// Static server settings derived from `ConfigDropshot` when the server
    /// is built.
    pub config: ServerConfig,
    /// Router built from the `ApiDescription`; used to look up the handler
    /// for each incoming request.
    pub router: HttpRouter<C>,
    /// Server-level logger (carries the `local_addr` key); per-request
    /// loggers are derived from it.
    pub log: Logger,
    /// Address the listening TCP socket is actually bound to.
    pub local_addr: SocketAddr,
    /// `Some` when the server was configured for TLS. The acceptor sits
    /// behind a mutex so `HttpServer::refresh_tls` can replace it while the
    /// server is running.
    pub(crate) tls_acceptor: Option<Arc<Mutex<TlsAcceptor>>>,
    /// Wait-group worker that is cloned into each detached handler task, so
    /// that shutdown can wait for those tasks to finish.
    pub(crate) handler_waitgroup_worker: DebugIgnore<waitgroup::Worker>,
    /// Policy used to determine the API version requested by each request.
    pub(crate) version_policy: VersionPolicy,
}
84
85impl<C: ServerContext> DropshotState<C> {
86 pub fn using_tls(&self) -> bool {
87 self.tls_acceptor.is_some()
88 }
89}
90
/// Static server configuration, assembled once in
/// `HttpServerStarter::new_internal` and stored in [`DropshotState`].
#[derive(Debug)]
pub struct ServerConfig {
    /// Copied from `ConfigDropshot::default_request_body_max_bytes`.
    // NOTE(review): presumably the default cap on request body size; it is
    // not consulted in this file.
    pub default_request_body_max_bytes: usize,
    /// Pagination limit; fixed at 10000 when the server is built.
    // NOTE(review): presumably the largest page a client may request.
    pub page_max_nitems: NonZeroU32,
    /// Pagination default; fixed at 100 when the server is built.
    // NOTE(review): presumably the page size used when none is requested.
    pub page_default_nitems: NonZeroU32,
    /// Selects whether a request handler is awaited inline (and so cancelled
    /// when the request future is dropped) or spawned as a detached task.
    pub default_handler_task_mode: HandlerTaskMode,
    /// Names of request headers whose values are attached to the per-request
    /// logger under `hdr_<name>` keys.
    pub log_headers: Vec<String>,
    /// Response compression setting; gzip handling is only considered when
    /// this is `CompressionConfig::Gzip`.
    pub compression: CompressionConfig,
}
115
/// A fully configured server that has bound its listening socket but is not
/// yet accepting connections. Call [`HttpServerStarter::start`] to run it.
pub struct HttpServerStarter<C: ServerContext> {
    /// Shared state handed to every connection and request handler.
    app_state: Arc<DropshotState<C>>,
    /// Address the listening socket is bound to.
    local_addr: SocketAddr,
    /// Wait group whose worker lives in `app_state`; awaited after the accept
    /// task finishes so detached handlers can complete.
    handler_waitgroup: WaitGroup,
    /// Accepts raw TCP connections from the bound listener.
    http_acceptor: HttpAcceptor,
    /// TLS acceptor shared with `app_state`; `None` for plain HTTP.
    tls_acceptor: Option<Arc<Mutex<TlsAcceptor>>>,
}
127
128impl<C: ServerContext> HttpServerStarter<C> {
129 pub fn new(
134 config: &ConfigDropshot,
135 api: ApiDescription<C>,
136 private: C,
137 log: &Logger,
138 ) -> Result<HttpServerStarter<C>, GenericError> {
139 HttpServerStarter::new_with_tls(config, api, private, log, None)
140 }
141
142 pub fn new_with_tls(
147 config: &ConfigDropshot,
148 api: ApiDescription<C>,
149 private: C,
150 log: &Logger,
151 tls: Option<ConfigTls>,
152 ) -> Result<HttpServerStarter<C>, GenericError> {
153 ServerBuilder::new(api, private, log.clone())
154 .config(config.clone())
155 .tls(tls)
156 .build_starter()
157 .map_err(|e| Box::new(e) as GenericError)
158 }
159
    /// Does the real work of building a starter: binds the listening socket,
    /// derives the server configuration, prepares the optional TLS acceptor,
    /// builds the router and assembles the shared [`DropshotState`].
    ///
    /// # Errors
    ///
    /// Returns a `BuildError` if the socket cannot be bound or configured, if
    /// the TLS configuration cannot be converted, or if `version_policy` is
    /// `Unversioned` while the API contains versioned routes.
    fn new_internal(
        config: &ConfigDropshot,
        api: ApiDescription<C>,
        private: C,
        log: &Logger,
        tls: Option<ConfigTls>,
        version_policy: VersionPolicy,
    ) -> Result<HttpServerStarter<C>, BuildError> {
        // Bind with the std listener first, switch it to non-blocking mode,
        // then hand it over to tokio.
        let tcp = {
            let std_listener = std::net::TcpListener::bind(
                &config.bind_address,
            )
            .map_err(|e| BuildError::bind_error(e, config.bind_address))?;
            std_listener.set_nonblocking(true).map_err(|e| {
                BuildError::generic_system(e, "setting non-blocking")
            })?;
            TcpListener::from_std(std_listener).map_err(|e| {
                BuildError::generic_system(e, "creating TCP listener")
            })?
        };

        // Ask the socket for its address rather than trusting the config.
        let local_addr = tcp.local_addr().map_err(|e| {
            BuildError::generic_system(e, "getting local TCP address")
        })?;

        // Every log record from this server carries the bound address.
        let log = log.new(o!("local_addr" => local_addr));

        let server_config = ServerConfig {
            default_request_body_max_bytes: config
                .default_request_body_max_bytes,
            page_max_nitems: NonZeroU32::new(10000).unwrap(),
            page_default_nitems: NonZeroU32::new(100).unwrap(),
            default_handler_task_mode: config.default_handler_task_mode,
            log_headers: config.log_headers.clone(),
            compression: config.compression,
        };

        // Convert the optional TLS settings into a shared, replaceable
        // acceptor; a conversion failure aborts the build.
        let tls_acceptor = tls
            .as_ref()
            .map(|tls| {
                Ok(Arc::new(Mutex::new(TlsAcceptor::from(Arc::new(
                    rustls::ServerConfig::try_from(tls)?,
                )))))
            })
            .transpose()?;
        let handler_waitgroup = WaitGroup::new();

        // An unversioned server must not expose version-specific routes.
        let router = api.into_router();
        if let VersionPolicy::Unversioned = version_policy {
            if router.has_versioned_routes() {
                return Err(BuildError::UnversionedServerHasVersionedRoutes);
            }
        }

        let app_state = Arc::new(DropshotState {
            private,
            config: server_config,
            router,
            log: log.clone(),
            local_addr,
            tls_acceptor: tls_acceptor.clone(),
            handler_waitgroup_worker: DebugIgnore(handler_waitgroup.worker()),
            version_policy,
        });

        // Log every registered endpoint at debug level.
        for (path, method, endpoint) in app_state.router.endpoints(None) {
            debug!(&log, "registered endpoint";
                "method" => &method,
                "path" => &path,
                "versions" => &endpoint.versions,
            );
        }

        let http_acceptor = HttpAcceptor { tcp, log: log.clone() };

        Ok(HttpServerStarter {
            app_state,
            local_addr,
            handler_waitgroup,
            http_acceptor,
            tls_acceptor,
        })
    }
246
    /// Starts accepting connections on a spawned tokio task and returns the
    /// [`HttpServer`] handle used to wait for or shut down the server.
    ///
    /// The accept task runs until the close channel fires (or its sender is
    /// dropped), then awaits `graceful.shutdown()`. The returned server's
    /// join future additionally waits for the handler wait group.
    pub fn start(self) -> HttpServer<C> {
        let HttpServerStarter {
            app_state,
            local_addr,
            handler_waitgroup,
            tls_acceptor,
            http_acceptor,
        } = self;

        // `tx` ends up in the server's `CloseHandle`; `rx` stops the accept
        // loop below.
        let (tx, mut rx) = tokio::sync::oneshot::channel::<()>();
        let make_service = ServerConnectionHandler::new(Arc::clone(&app_state));
        let log = &app_state.log;
        let log_close = log.clone();
        let join_handle = tokio::spawn(async move {
            use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
            use hyper_util::server::conn::auto;

            // Connection builder that serves both HTTP/1 and HTTP/2, each
            // configured with a tokio timer.
            let mut builder = auto::Builder::new(TokioExecutor::new());
            builder.http1().timer(TokioTimer::new());
            builder.http2().timer(TokioTimer::new());

            // Every served connection is registered with this watcher so it
            // is covered by the `graceful.shutdown()` await at the end.
            let graceful =
                hyper_util::server::graceful::GracefulShutdown::new();

            let log = log_close;
            match tls_acceptor {
                // HTTPS: connections arrive already TLS-negotiated.
                Some(tls_acceptor) => {
                    let mut https_acceptor = HttpsAcceptor::new(
                        log.clone(),
                        tls_acceptor,
                        http_acceptor,
                    );
                    loop {
                        tokio::select! {
                            Some(Ok(sock)) = https_acceptor.accept() => {
                                let remote_addr = sock.remote_addr();
                                let handler = make_service
                                    .make_http_request_handler(remote_addr);
                                let fut = builder
                                    .serve_connection_with_upgrades(
                                        TokioIo::new(sock),
                                        handler,
                                    );
                                let fut = graceful.watch(fut.into_owned());
                                // Each connection is served on its own task.
                                tokio::spawn(fut);
                            },

                            // Close requested, or the sender was dropped.
                            _ = &mut rx => {
                                info!(log, "beginning graceful shutdown");
                                break;
                            }
                        }
                    }
                }
                // Plain HTTP: serve the raw TCP stream directly.
                None => loop {
                    tokio::select! {
                        (sock, remote_addr) = http_acceptor.accept() => {
                            let handler = make_service
                                .make_http_request_handler(remote_addr);
                            let fut = builder
                                .serve_connection_with_upgrades(
                                    TokioIo::new(sock),
                                    handler,
                                );
                            let fut = graceful.watch(fut.into_owned());
                            // Each connection is served on its own task.
                            tokio::spawn(fut);
                        },

                        // Close requested, or the sender was dropped.
                        _ = &mut rx => {
                            info!(log, "beginning graceful shutdown");
                            break;
                        }
                    }
                },
            };

            // No new connections are accepted past this point.
            graceful.shutdown().await
        });

        info!(log, "listening");

        // Completion means: the accept task has finished (a join error is
        // turned into a message) and the handler wait group has drained.
        let join_handle = async move {
            () = join_handle
                .await
                .map_err(|e| format!("server stopped: {e}"))?;
            () = handler_waitgroup.wait().await;
            Ok(())
        };

        // Register DTrace USDT probes when compiled in, and record the
        // outcome so callers can inspect it.
        #[cfg(feature = "usdt-probes")]
        let probe_registration = match usdt::register_probes() {
            Ok(_) => {
                debug!(&log, "successfully registered DTrace USDT probes");
                ProbeRegistration::Succeeded
            }
            Err(e) => {
                let msg = e.to_string();
                error!(&log, "failed to register DTrace USDT probes: {}", msg);
                ProbeRegistration::Failed(msg)
            }
        };
        #[cfg(not(feature = "usdt-probes"))]
        let probe_registration = {
            debug!(&log, "DTrace USDT probes compiled out, not registering");
            ProbeRegistration::Disabled
        };

        HttpServer {
            probe_registration,
            app_state,
            local_addr,
            closer: CloseHandle { close_channel: Some(tx) },
            // Shared so that multiple waiters can observe the same result.
            join_future: join_handle.boxed().shared(),
        }
    }
373}
374
/// Accepts plain TCP connections from the server's listening socket.
struct HttpAcceptor {
    /// The bound, non-blocking tokio listener.
    tcp: TcpListener,
    /// Logger used to report accept failures.
    log: slog::Logger,
}
381
382impl HttpAcceptor {
383 async fn accept(&self) -> (TcpStream, SocketAddr) {
384 loop {
385 match self.tcp.accept().await {
386 Ok((socket, addr)) => return (socket, addr),
387 Err(e) => match e.kind() {
388 std::io::ErrorKind::ConnectionRefused
391 | std::io::ErrorKind::ConnectionAborted
392 | std::io::ErrorKind::ConnectionReset => (),
393
394 _ => {
397 warn!(self.log, "accept error"; "error" => e);
398 tokio::time::sleep(std::time::Duration::from_millis(
399 100,
400 ))
401 .await;
402 }
403 },
404 }
405 }
406 }
407}
408
/// A TLS-negotiated connection paired with the address of the remote peer.
#[derive(Debug)]
struct TlsConn {
    /// The encrypted stream layered over the accepted TCP socket.
    stream: TlsStream<TcpStream>,
    /// Peer address recorded when the TCP connection was accepted.
    remote_addr: SocketAddr,
}
415
416impl TlsConn {
417 fn new(stream: TlsStream<TcpStream>, remote_addr: SocketAddr) -> TlsConn {
418 TlsConn { stream, remote_addr }
419 }
420
421 fn remote_addr(&self) -> SocketAddr {
422 self.remote_addr
423 }
424}
425
426impl tokio::io::AsyncRead for TlsConn {
428 fn poll_read(
429 mut self: Pin<&mut Self>,
430 ctx: &mut core::task::Context,
431 buf: &mut ReadBuf,
432 ) -> Poll<std::io::Result<()>> {
433 let pinned = Pin::new(&mut self.stream);
434 pinned.poll_read(ctx, buf)
435 }
436}
437
438impl tokio::io::AsyncWrite for TlsConn {
440 fn poll_write(
441 mut self: Pin<&mut Self>,
442 ctx: &mut core::task::Context,
443 data: &[u8],
444 ) -> Poll<std::io::Result<usize>> {
445 let pinned = Pin::new(&mut self.stream);
446 pinned.poll_write(ctx, data)
447 }
448
449 fn poll_flush(
450 mut self: Pin<&mut Self>,
451 ctx: &mut core::task::Context,
452 ) -> Poll<std::io::Result<()>> {
453 let pinned = Pin::new(&mut self.stream);
454 pinned.poll_flush(ctx)
455 }
456
457 fn poll_shutdown(
458 mut self: Pin<&mut Self>,
459 ctx: &mut core::task::Context,
460 ) -> Poll<std::io::Result<()>> {
461 let pinned = Pin::new(&mut self.stream);
462 pinned.poll_shutdown(ctx)
463 }
464}
465
/// Accepts TCP connections and performs the TLS handshake on each, exposing
/// the results as a stream of negotiated connections.
struct HttpsAcceptor {
    /// Boxed stream produced by `HttpsAcceptor::new_stream`.
    stream: Box<dyn Stream<Item = std::io::Result<TlsConn>> + Send + Unpin>,
}
477
impl HttpsAcceptor {
    /// Wraps `http_acceptor` so that every accepted TCP connection goes
    /// through a TLS handshake using the current contents of `tls_acceptor`.
    pub fn new(
        log: slog::Logger,
        tls_acceptor: Arc<Mutex<TlsAcceptor>>,
        http_acceptor: HttpAcceptor,
    ) -> HttpsAcceptor {
        HttpsAcceptor {
            // Pinned on the heap so the boxed trait object is `Unpin`.
            stream: Box::new(Box::pin(Self::new_stream(
                log,
                tls_acceptor,
                http_acceptor,
            ))),
        }
    }

    /// Returns the next item of the underlying stream, i.e. the next
    /// connection whose TLS handshake has finished.
    async fn accept(&mut self) -> Option<std::io::Result<TlsConn>> {
        self.stream.next().await
    }

    /// Builds the stream backing this acceptor.
    ///
    /// The stream concurrently (a) accepts new TCP connections and starts a
    /// TLS handshake for each, and (b) drives all in-progress handshakes,
    /// yielding each one that succeeds. Failed handshakes are logged and
    /// dropped; they are never yielded as `Err`.
    fn new_stream(
        log: slog::Logger,
        tls_acceptor: Arc<Mutex<TlsAcceptor>>,
        http_acceptor: HttpAcceptor,
    ) -> impl Stream<Item = std::io::Result<TlsConn>> {
        stream! {
            // Handshakes currently in flight.
            let mut tls_negotiations = futures::stream::FuturesUnordered::new();
            loop {
                tokio::select! {
                    // Only poll the set when it has handshakes in it.
                    Some(negotiation) = tls_negotiations.next(), if
                    !tls_negotiations.is_empty() => {

                        match negotiation {
                            Ok(conn) => yield Ok(conn),
                            Err(e) => {
                                // Log and carry on; one failed handshake
                                // must not end the stream.
                                warn!(log, "tls accept err: {}", e);
                            },
                        }
                    },
                    (socket, addr) = http_acceptor.accept() => {
                        // Lock just long enough to start the handshake, so
                        // a refreshed acceptor applies to new connections.
                        let tls_negotiation = tls_acceptor
                            .lock()
                            .await
                            .accept(socket)
                            .map_ok(move |stream| TlsConn::new(stream, addr));
                        tls_negotiations.push(tls_negotiation);
                    },
                    else => break,
                }
            }
        }
    }
}
540
541impl TryFrom<&ConfigTls> for rustls::ServerConfig {
543 type Error = BuildError;
544
545 fn try_from(config: &ConfigTls) -> Result<Self, Self::Error> {
546 let (mut cert_reader, mut key_reader): (
547 Box<dyn std::io::BufRead>,
548 Box<dyn std::io::BufRead>,
549 ) = match config {
550 ConfigTls::Dynamic(raw) => {
551 return Ok(raw.clone());
552 }
553 ConfigTls::AsBytes { certs, key } => (
554 Box::new(std::io::BufReader::new(certs.as_slice())),
555 Box::new(std::io::BufReader::new(key.as_slice())),
556 ),
557 ConfigTls::AsFile { cert_file, key_file } => {
558 let certfile = Box::new(std::io::BufReader::new(
559 std::fs::File::open(cert_file).map_err(|e| {
560 BuildError::generic_system(
561 e,
562 format!("opening {}", cert_file.display()),
563 )
564 })?,
565 ));
566 let keyfile = Box::new(std::io::BufReader::new(
567 std::fs::File::open(key_file).map_err(|e| {
568 BuildError::generic_system(
569 e,
570 format!("opening {}", key_file.display()),
571 )
572 })?,
573 ));
574 (certfile, keyfile)
575 }
576 };
577
578 let certs = rustls_pemfile::certs(&mut cert_reader)
579 .collect::<Result<Vec<_>, _>>()
580 .map_err(|err| {
581 BuildError::generic_system(err, "loading TLS certificates")
582 })?;
583 let keys = rustls_pemfile::pkcs8_private_keys(&mut key_reader)
584 .collect::<Result<Vec<_>, _>>()
585 .map_err(|err| {
586 BuildError::generic_system(err, "loading TLS private key")
587 })?;
588 let mut keys_iter = keys.into_iter();
589 let (Some(private_key), None) = (keys_iter.next(), keys_iter.next())
590 else {
591 return Err(BuildError::NotOnePrivateKey);
592 };
593
594 let mut cfg = rustls::ServerConfig::builder()
595 .with_no_client_auth()
596 .with_single_cert(certs, private_key.into())
597 .expect("bad certificate/key");
598 cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
599 Ok(cfg)
600 }
601}
602
/// A boxed, `Send` future wrapped in `Shared`, so it can be cloned and
/// awaited from several places (its output must be `Clone`).
type SharedBoxFuture<T> = Shared<Pin<Box<dyn Future<Output = T> + Send>>>;
604
/// Future returned by `HttpServer::wait_for_shutdown`. It wraps a clone of
/// the server's shared join future and resolves with that future's result.
pub struct ShutdownWaitFuture(SharedBoxFuture<Result<(), String>>);
607
608impl Future for ShutdownWaitFuture {
609 type Output = Result<(), String>;
610
611 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
612 Pin::new(&mut self.get_mut().0).poll(cx)
613 }
614}
615
616impl FusedFuture for ShutdownWaitFuture {
617 fn is_terminated(&self) -> bool {
618 self.0.is_terminated()
619 }
620}
621
/// Handle to a running server, returned by `HttpServerStarter::start`.
///
/// It is itself a future that resolves when the server has stopped. Dropping
/// it signals the accept task to stop (see the `Drop` impl of `CloseHandle`).
pub struct HttpServer<C: ServerContext> {
    /// Outcome of registering DTrace USDT probes at start time.
    probe_registration: ProbeRegistration,
    /// Shared server state; also reachable from every request handler.
    app_state: Arc<DropshotState<C>>,
    /// Address the server is listening on.
    local_addr: SocketAddr,
    /// Holds the sender that tells the accept task to stop.
    closer: CloseHandle,
    /// Resolves once the accept task has finished and the handler wait group
    /// has drained; shared so multiple waiters see the same result.
    join_future: SharedBoxFuture<Result<(), String>>,
}
633
/// Owns the sending half of the server's close channel. It is an `Option` so
/// the sender can be taken exactly once, by `HttpServer::close` or on drop.
struct CloseHandle {
    /// `Some` until a close signal has been sent.
    close_channel: Option<tokio::sync::oneshot::Sender<()>>,
}
638
639impl<C: ServerContext> HttpServer<C> {
640 pub fn local_addr(&self) -> SocketAddr {
641 self.local_addr
642 }
643
644 pub fn app_private(&self) -> &C {
645 &self.app_state.private
646 }
647
648 pub fn using_tls(&self) -> bool {
649 self.app_state.using_tls()
650 }
651
652 pub async fn refresh_tls(&self, config: &ConfigTls) -> Result<(), String> {
654 let acceptor = &self
655 .app_state
656 .tls_acceptor
657 .as_ref()
658 .ok_or_else(|| "Not configured for TLS".to_string())?;
659
660 *acceptor.lock().await = TlsAcceptor::from(Arc::new(
661 rustls::ServerConfig::try_from(config).unwrap(),
662 ));
663 Ok(())
664 }
665
666 pub fn probe_registration(&self) -> &ProbeRegistration {
670 &self.probe_registration
671 }
672
673 pub fn wait_for_shutdown(&self) -> ShutdownWaitFuture {
681 ShutdownWaitFuture(self.join_future.clone())
682 }
683
684 pub async fn close(mut self) -> Result<(), String> {
686 self.closer
687 .close_channel
688 .take()
689 .expect("cannot close twice")
690 .send(())
691 .expect("failed to send close signal");
692
693 mem::drop(self.app_state);
699
700 self.join_future.await
701 }
702}
703
704impl Drop for CloseHandle {
709 fn drop(&mut self) {
710 if let Some(c) = self.close_channel.take() {
711 let _ = c.send(());
716 }
717 }
718}
719
720impl<C: ServerContext> Future for HttpServer<C> {
721 type Output = Result<(), String>;
722
723 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
724 let server = Pin::into_inner(self);
725 let join_future = Pin::new(&mut server.join_future);
726 join_future.poll(cx)
727 }
728}
729
730impl<C: ServerContext> FusedFuture for HttpServer<C> {
731 fn is_terminated(&self) -> bool {
732 self.join_future.is_terminated()
733 }
734}
735
736async fn http_request_handle_wrap<C: ServerContext>(
741 server: Arc<DropshotState<C>>,
742 remote_addr: SocketAddr,
743 request: Request<hyper::body::Incoming>,
744) -> Result<Response<Body>, GenericError> {
745 let start_time = std::time::Instant::now();
750 let request_id = generate_request_id();
751
752 let mut request_log = server.log.new(o!(
753 "remote_addr" => remote_addr,
754 "req_id" => request_id.clone(),
755 "method" => request.method().as_str().to_string(),
756 "uri" => format!("{}", request.uri()),
757 ));
758 for name in server.config.log_headers.iter() {
761 let v = request
762 .headers()
763 .get(name)
764 .and_then(|v| v.to_str().ok().map(str::to_string));
765
766 if let Some(v) = v {
767 let k = format!("hdr_{}", name.to_lowercase().replace('-', "_"));
780 request_log = request_log.new(o!(k => v));
781 }
782 }
783
784 trace!(request_log, "incoming request");
785 #[cfg(feature = "usdt-probes")]
786 probes::request__start!(|| {
787 let uri = request.uri();
788 crate::dtrace::RequestInfo {
789 id: request_id.clone(),
790 local_addr: server.local_addr,
791 remote_addr,
792 method: request.method().to_string(),
793 path: uri.path().to_string(),
794 query: uri.query().map(|x| x.to_string()),
795 }
796 });
797
798 #[cfg(feature = "usdt-probes")]
801 let local_addr = server.local_addr;
802
803 let on_disconnect = guard((), |_| {
806 let latency_us = start_time.elapsed().as_micros();
807
808 warn!(request_log, "request handling cancelled (client disconnected)";
809 "latency_us" => latency_us,
810 );
811
812 #[cfg(feature = "usdt-probes")]
813 probes::request__done!(|| {
814 crate::dtrace::ResponseInfo {
815 id: request_id.clone(),
816 local_addr,
817 remote_addr,
818 status_code: 499,
820 message: String::from(
821 "client disconnected before response returned",
822 ),
823 }
824 });
825 });
826
827 let maybe_response = http_request_handle(
828 server,
829 request,
830 &request_id,
831 request_log.new(o!()),
832 remote_addr,
833 )
834 .await;
835
836 let _ = ScopeGuard::into_inner(on_disconnect);
839
840 let latency_us = start_time.elapsed().as_micros();
841 let response = match maybe_response {
842 Err(error) => {
843 {
844 let status = error.status_code();
845 let message_external = error.external_message();
846 let message_internal = error.internal_message();
847
848 #[cfg(feature = "usdt-probes")]
849 probes::request__done!(|| {
850 crate::dtrace::ResponseInfo {
851 id: request_id.clone(),
852 local_addr,
853 remote_addr,
854 status_code: status.as_u16(),
855 message: message_external
856 .cloned()
857 .unwrap_or_else(|| message_internal.clone()),
858 }
859 });
860
861 info!(request_log, "request completed";
863 "response_code" => status.as_u16(),
864 "latency_us" => latency_us,
865 "error_message_internal" => message_internal,
866 "error_message_external" => message_external,
867 );
868 };
869 error.into_response(&request_id)
870 }
871
872 Ok(response) => {
873 info!(request_log, "request completed";
875 "response_code" => response.status().as_u16(),
876 "latency_us" => latency_us,
877 );
878
879 #[cfg(feature = "usdt-probes")]
880 probes::request__done!(|| {
881 crate::dtrace::ResponseInfo {
882 id: request_id.parse().unwrap(),
883 local_addr,
884 remote_addr,
885 status_code: response.status().as_u16(),
886 message: "".to_string(),
887 }
888 });
889
890 response
891 }
892 };
893
894 Ok(response)
895}
896
/// Routes one request to its handler and produces the response.
///
/// Steps: wrap the body, determine the requested API version, look up the
/// route, run the handler (inline or on a detached task, depending on
/// `default_handler_task_mode`), optionally gzip the response, and stamp the
/// response with the request-id header.
///
/// # Errors
///
/// Returns the `HandlerError` produced by version detection, route lookup or
/// the handler itself.
///
/// # Panics
///
/// In detached mode, a panic inside the handler task is re-raised here.
async fn http_request_handle<C: ServerContext>(
    server: Arc<DropshotState<C>>,
    request: Request<hyper::body::Incoming>,
    request_id: &str,
    request_log: Logger,
    remote_addr: std::net::SocketAddr,
) -> Result<Response<Body>, HandlerError> {
    let request = request.map(crate::Body::wrap);
    // Cloned because `request` is later moved into the handler, while the
    // method is still needed for the compression decision.
    let method = request.method().clone();
    let uri = request.uri();
    let found_version =
        server.version_policy.request_version(&request, &request_log)?;
    let lookup_result = server.router.lookup_route(
        &method,
        uri.path().into(),
        found_version.as_ref(),
    )?;
    let rqctx = RequestContext {
        server: Arc::clone(&server),
        request: RequestInfo::new(&request, remote_addr),
        endpoint: lookup_result.endpoint,
        request_id: request_id.to_string(),
        log: request_log.clone(),
    };
    // Copied before `rqctx` is moved into the handler; needed afterwards to
    // decide on compression.
    let request_headers = rqctx.request.headers().clone();
    let handler = lookup_result.handler;

    let mut response = match server.config.default_handler_task_mode {
        HandlerTaskMode::CancelOnDisconnect => {
            // Awaited inline: dropping this future also drops the handler.
            handler.handle_request(rqctx, request).await?
        }
        HandlerTaskMode::Detached => {
            // Run the handler on its own task and receive its result over a
            // oneshot channel, so the handler keeps running even if this
            // future is dropped.
            let (tx, rx) = oneshot::channel();
            let request_log = request_log.clone();
            // Held by the task until it finishes, so the server's wait
            // group covers this handler.
            let worker = server.handler_waitgroup_worker.clone();
            let handler_task = tokio::spawn(async move {
                let request_log = rqctx.log.clone();
                let result = handler.handle_request(rqctx, request).await;

                // A failed send means the receiving side is gone; log the
                // outcome, since nobody will see the result.
                if let Err(result) = tx.send(result) {
                    match result {
                        Ok(r) => warn!(
                            request_log, "request completed after handler was already cancelled";
                            "response_code" => r.status().as_u16(),
                        ),
                        Err(error) => {
                            warn!(request_log, "request completed after handler was already cancelled";
                                "response_code" => error.status_code().as_u16(),
                                "error_message_internal" => error.internal_message(),
                                "error_message_external" => error.external_message(),
                            );
                        }
                    }
                }

                mem::drop(worker);
            });

            match rx.await {
                Ok(result) => result?,
                Err(_) => {
                    // The sender was dropped without sending, which this
                    // code treats as the task having panicked.
                    error!(request_log, "handler panicked; propagating panic");

                    let task_err = handler_task.await.expect_err(
                        "task failed to send result but didn't panic",
                    );
                    panic::resume_unwind(task_err.into_panic());
                }
            }
        }
    };

    // With gzip enabled, every compressible response gets the Vary header,
    // whether or not this particular response ends up compressed.
    if matches!(server.config.compression, CompressionConfig::Gzip)
        && is_compressible_content_type(response.headers())
    {
        add_vary_header(response.headers_mut());

        if should_compress_response(
            &method,
            &request_headers,
            response.status(),
            response.headers(),
            response.extensions(),
        ) {
            response = apply_gzip_compression(response);
        }
    }

    // Echo the request id back to the client.
    response.headers_mut().insert(
        HEADER_REQUEST_ID,
        http::header::HeaderValue::from_str(&request_id).unwrap(),
    );
    Ok(response)
}
1022
1023fn generate_request_id() -> String {
1029 format!("{}", Uuid::new_v4())
1030}
1031
/// Factory for per-connection request handlers; one instance is shared by the
/// accept loop for the lifetime of the server.
pub struct ServerConnectionHandler<C: ServerContext> {
    /// Shared server state handed to every request handler it creates.
    server: Arc<DropshotState<C>>,
}
1041
1042impl<C: ServerContext> ServerConnectionHandler<C> {
1043 fn new(server: Arc<DropshotState<C>>) -> Self {
1046 ServerConnectionHandler { server }
1047 }
1048
1049 fn make_http_request_handler(
1054 &self,
1055 remote_addr: SocketAddr,
1056 ) -> ServerRequestHandler<C> {
1057 info!(self.server.log, "accepted connection"; "remote_addr" => %remote_addr);
1058 ServerRequestHandler::new(self.server.clone(), remote_addr)
1059 }
1060}
1061
/// Per-connection `Service` implementation: handles each request arriving on
/// one connection.
pub struct ServerRequestHandler<C: ServerContext> {
    /// Shared server state.
    server: Arc<DropshotState<C>>,
    /// Address of the peer on this connection.
    remote_addr: SocketAddr,
}
1072
1073impl<C: ServerContext> ServerRequestHandler<C> {
1074 fn new(server: Arc<DropshotState<C>>, remote_addr: SocketAddr) -> Self {
1077 ServerRequestHandler { server, remote_addr }
1078 }
1079}
1080
1081impl<C: ServerContext> Service<Request<hyper::body::Incoming>>
1082 for ServerRequestHandler<C>
1083{
1084 type Response = Response<Body>;
1085 type Error = GenericError;
1086 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
1087
1088 fn call(&self, req: Request<hyper::body::Incoming>) -> Self::Future {
1089 Box::pin(http_request_handle_wrap(
1090 Arc::clone(&self.server),
1091 self.remote_addr,
1092 req,
1093 ))
1094 }
1095}
1096
/// Errors that can occur while building a server.
#[derive(Debug, Error)]
pub enum BuildError {
    /// The listening socket could not be bound to `address`.
    #[error("failed to bind to {address}")]
    BindError {
        /// Address the bind was attempted on.
        address: SocketAddr,
        /// Underlying I/O error.
        #[source]
        error: std::io::Error,
    },
    /// The TLS key input did not contain exactly one private key.
    #[error("expected exactly one TLS private key")]
    NotOnePrivateKey,
    /// An I/O operation failed; `context` says what was being attempted.
    #[error("{context}")]
    SystemError {
        /// Short description of the failed operation.
        context: String,
        /// Underlying I/O error.
        #[source]
        error: std::io::Error,
    },
    /// The server uses `VersionPolicy::Unversioned` but the API contains
    /// version-specific routes.
    #[error(
        "unversioned servers cannot have endpoints with specific versions"
    )]
    UnversionedServerHasVersionedRoutes,
}
1119
1120impl BuildError {
1121 fn bind_error(error: std::io::Error, address: SocketAddr) -> BuildError {
1123 BuildError::BindError { address, error }
1124 }
1125
1126 fn generic_system<S: Into<String>>(
1131 error: std::io::Error,
1132 context: S,
1133 ) -> BuildError {
1134 BuildError::SystemError { context: context.into(), error }
1135 }
1136}
1137
/// Builder for a server. Required inputs are given to `ServerBuilder::new`;
/// everything else has a default and can be overridden with the setters.
#[derive(Debug)]
pub struct ServerBuilder<C: ServerContext> {
    // Required inputs.
    /// Caller-provided context value.
    private: C,
    /// Base logger for the server.
    log: Logger,
    /// API to serve (excluded from the `Debug` output).
    api: DebugIgnore<ApiDescription<C>>,

    // Optional settings.
    /// Server configuration; defaults to `ConfigDropshot::default()`.
    config: ConfigDropshot,
    /// Versioning policy; defaults to `VersionPolicy::Unversioned`.
    version_policy: VersionPolicy,
    /// TLS settings; `None` (the default) means plain HTTP.
    tls: Option<ConfigTls>,
}
1151
1152impl<C: ServerContext> ServerBuilder<C> {
1153 pub fn new(
1160 api: ApiDescription<C>,
1161 private: C,
1162 log: Logger,
1163 ) -> ServerBuilder<C> {
1164 ServerBuilder {
1165 private,
1166 log,
1167 api: DebugIgnore(api),
1168 config: Default::default(),
1169 version_policy: VersionPolicy::Unversioned,
1170 tls: Default::default(),
1171 }
1172 }
1173
1174 pub fn config(mut self, config: ConfigDropshot) -> Self {
1176 self.config = config;
1177 self
1178 }
1179
1180 pub fn tls(mut self, tls: Option<ConfigTls>) -> Self {
1185 self.tls = tls;
1186 self
1187 }
1188
1189 pub fn version_policy(mut self, version_policy: VersionPolicy) -> Self {
1195 self.version_policy = version_policy;
1196 self
1197 }
1198
1199 pub fn start(self) -> Result<HttpServer<C>, BuildError> {
1205 Ok(self.build_starter()?.start())
1206 }
1207
1208 pub fn build_starter(self) -> Result<HttpServerStarter<C>, BuildError> {
1224 HttpServerStarter::new_internal(
1225 &self.config,
1226 self.api.0,
1227 self.private,
1228 &self.log,
1229 self.tls,
1230 self.version_policy,
1231 )
1232 }
1233}
1234
/// Tests for server start-up, shutdown and the TCP acceptor.
#[cfg(test)]
mod test {
    use super::*;
    // Alias the crate under its public name, so the `dropshot::` paths below
    // resolve from inside the crate.
    use crate as dropshot;
    use dropshot::endpoint;
    use dropshot::test_util::ClientTestContext;
    use dropshot::test_util::LogContext;
    use dropshot::ConfigLogging;
    use dropshot::ConfigLoggingLevel;
    use dropshot::HttpError;
    use dropshot::HttpResponseOk;
    use dropshot::RequestContext;
    use http::StatusCode;
    use hyper::Method;

    use futures::future::FusedFuture;

    /// Minimal endpoint: `GET /handler` always answers 200 with the value 3.
    #[endpoint {
        method = GET,
        path = "/handler",
    }]
    async fn handler(
        _rqctx: RequestContext<i32>,
    ) -> Result<HttpResponseOk<u64>, HttpError> {
        Ok(HttpResponseOk(3))
    }

    /// Keeps the log context alive for the duration of a test.
    struct TestConfig {
        log_context: LogContext,
    }

    impl TestConfig {
        /// Returns the test logger.
        fn log(&self) -> &slog::Logger {
            &self.log_context.log
        }
    }

    /// Starts a server with the default configuration, the single `handler`
    /// endpoint and `0` as its private context.
    fn create_test_server() -> (HttpServer<i32>, TestConfig) {
        let config_dropshot = ConfigDropshot::default();

        let mut api = ApiDescription::new();
        api.register(handler).unwrap();

        let config_logging =
            ConfigLogging::StderrTerminal { level: ConfigLoggingLevel::Warn };
        let log_context = LogContext::new("test server", &config_logging);
        let log = &log_context.log;

        let server = HttpServerStarter::new(&config_dropshot, api, 0, log)
            .unwrap()
            .start();

        (server, TestConfig { log_context })
    }

    /// Issues one `GET /handler` request against `addr` on a spawned task and
    /// asserts that it succeeds with 200 OK.
    async fn single_client_request(addr: SocketAddr, log: &slog::Logger) {
        let client_log = log.new(o!("http_client" => "dropshot test suite"));
        let client_testctx = ClientTestContext::new(addr, client_log);
        tokio::task::spawn(async move {
            let response = client_testctx
                .make_request(
                    Method::GET,
                    "/handler",
                    None as Option<()>,
                    StatusCode::OK,
                )
                .await;

            assert!(response.is_ok());
        })
        .await
        .expect("client request failed");
    }

    /// The server must keep running while a client request is served, and
    /// must then shut down cleanly on `close()`.
    #[tokio::test]
    async fn test_server_run_then_close() {
        let (mut server, config) = create_test_server();
        let client = single_client_request(server.local_addr(), config.log());

        // Race the client against the server: the client must finish first.
        futures::select! {
            _ = client.fuse() => {},
            r = server => panic!("Server unexpectedly terminated: {:?}", r),
        }

        assert!(!server.is_terminated());
        assert!(server.close().await.is_ok());
    }

    /// Dropping a server without calling `close()` must not panic.
    #[tokio::test]
    async fn test_drop_server_without_close_okay() {
        let (server, _) = create_test_server();
        std::mem::drop(server);
    }

    /// `HttpAcceptor::accept` hands back every one of a run of sequential
    /// client connections.
    #[tokio::test]
    async fn test_http_acceptor_happy_path() {
        const TOTAL: usize = 100;
        let tcp =
            tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let addr = tcp.local_addr().expect("local_addr");
        let acceptor =
            HttpAcceptor { log: slog::Logger::root(slog::Discard, o!()), tcp };

        // Server side: accept exactly TOTAL connections.
        let t1 = tokio::spawn(async move {
            for _ in 0..TOTAL {
                let _ = acceptor.accept().await;
            }
        });

        // Client side: open TOTAL connections, one after another.
        let t2 = tokio::spawn(async move {
            for _ in 0..TOTAL {
                tokio::net::TcpStream::connect(&addr).await.expect("connect");
            }
        });

        t1.await.expect("task 1");
        t2.await.expect("task 2");
    }
}