Skip to main content

dropshot/
server.rs

1// Copyright 2024 Oxide Computer Company
2//! Generic server-wide state and facilities
3
4use super::api_description::ApiDescription;
5use super::body::Body;
6use super::compression::add_vary_header;
7use super::compression::apply_gzip_compression;
8use super::compression::is_compressible_content_type;
9use super::compression::should_compress_response;
10use super::config::{CompressionConfig, ConfigDropshot, ConfigTls};
11#[cfg(feature = "usdt-probes")]
12use super::dtrace::probes;
13use super::handler::HandlerError;
14use super::handler::RequestContext;
15use super::http_util::HEADER_REQUEST_ID;
16use super::router::HttpRouter;
17use super::versioning::VersionPolicy;
18use super::ProbeRegistration;
19
20use async_stream::stream;
21use debug_ignore::DebugIgnore;
22use futures::future::{
23    BoxFuture, FusedFuture, FutureExt, Shared, TryFutureExt,
24};
25use futures::lock::Mutex;
26use futures::stream::{Stream, StreamExt};
27use hyper::service::Service;
28use hyper::Request;
29use hyper::Response;
30use rustls;
31use scopeguard::{guard, ScopeGuard};
32use std::convert::TryFrom;
33use std::future::Future;
34use std::mem;
35use std::net::SocketAddr;
36use std::num::NonZeroU32;
37use std::panic;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41use tokio::io::ReadBuf;
42use tokio::net::{TcpListener, TcpStream};
43use tokio::sync::oneshot;
44use tokio_rustls::{server::TlsStream, TlsAcceptor};
45use uuid::Uuid;
46use waitgroup::WaitGroup;
47
48use crate::config::HandlerTaskMode;
49use crate::RequestInfo;
50use slog::Logger;
51use thiserror::Error;
52
// Boxed, thread-safe error type returned by the backwards-compatible
// `HttpServerStarter` constructors.
// TODO Remove when we can remove `HttpServerStarter`
type GenericError = Box<dyn std::error::Error + Send + Sync>;
55
/// Marker trait for the context value that endpoints can reach through a
/// server.
///
/// Nothing needs to implement this by hand: the blanket impl below covers
/// every `Send + Sync + 'static` type.
pub trait ServerContext: Send + Sync + 'static {}

impl<T> ServerContext for T where T: Send + Sync + 'static {}
62
63/// Stores shared state used by the Dropshot server.
64#[derive(Debug)]
65pub struct DropshotState<C: ServerContext> {
66    /// caller-specific state
67    pub private: C,
68    /// static server configuration parameters
69    pub config: ServerConfig,
70    /// request router
71    pub router: HttpRouter<C>,
72    /// server-wide log handle
73    pub log: Logger,
74    /// bound local address for the server.
75    pub local_addr: SocketAddr,
76    /// Identifies how to accept TLS connections
77    pub(crate) tls_acceptor: Option<Arc<Mutex<TlsAcceptor>>>,
78    /// Worker for the handler_waitgroup associated with this server, allowing
79    /// graceful shutdown to wait for all handlers to complete.
80    pub(crate) handler_waitgroup_worker: DebugIgnore<waitgroup::Worker>,
81    /// specifies how incoming requests are mapped to handlers based on versions
82    pub(crate) version_policy: VersionPolicy,
83}
84
85impl<C: ServerContext> DropshotState<C> {
86    pub fn using_tls(&self) -> bool {
87        self.tls_acceptor.is_some()
88    }
89}
90
91/// Stores static configuration associated with the server
92/// TODO-cleanup merge with ConfigDropshot
93#[derive(Debug)]
94pub struct ServerConfig {
95    /// maximum allowed size of a request body
96    pub default_request_body_max_bytes: usize,
97    /// maximum size of any page of results
98    pub page_max_nitems: NonZeroU32,
99    /// default size for a page of results
100    pub page_default_nitems: NonZeroU32,
101    /// Default behavior for HTTP handler functions with respect to clients
102    /// disconnecting early.
103    pub default_handler_task_mode: HandlerTaskMode,
104    /// A list of header names to include as extra properties in the log
105    /// messages emitted by the per-request logger.  Each header will, if
106    /// present, be included in the output with a "hdr_"-prefixed property name
107    /// in lower case that has all hyphens replaced with underscores; e.g.,
108    /// "X-Forwarded-For" will be included as "hdr_x_forwarded_for".  No attempt
109    /// is made to deal with headers that appear multiple times in a single
110    /// request.
111    pub log_headers: Vec<String>,
112    /// Configuration for response compression.
113    pub compression: CompressionConfig,
114}
115
116/// See [`ServerBuilder`] instead.
117// It would be nice to remove this structure altogether once we've got
118// confidence that no consumers actually need to distinguish between the
119// configuration and start steps.
120pub struct HttpServerStarter<C: ServerContext> {
121    app_state: Arc<DropshotState<C>>,
122    local_addr: SocketAddr,
123    handler_waitgroup: WaitGroup,
124    http_acceptor: HttpAcceptor,
125    tls_acceptor: Option<Arc<Mutex<TlsAcceptor>>>,
126}
127
128impl<C: ServerContext> HttpServerStarter<C> {
129    /// Make an `HttpServerStarter` to start an `HttpServer`
130    ///
131    /// This function exists for backwards compatibility.  You should use
132    /// [`ServerBuilder`] instead.
133    pub fn new(
134        config: &ConfigDropshot,
135        api: ApiDescription<C>,
136        private: C,
137        log: &Logger,
138    ) -> Result<HttpServerStarter<C>, GenericError> {
139        HttpServerStarter::new_with_tls(config, api, private, log, None)
140    }
141
142    /// Make an `HttpServerStarter` to start an `HttpServer`
143    ///
144    /// This function exists for backwards compatibility.  You should use
145    /// [`ServerBuilder`] instead.
146    pub fn new_with_tls(
147        config: &ConfigDropshot,
148        api: ApiDescription<C>,
149        private: C,
150        log: &Logger,
151        tls: Option<ConfigTls>,
152    ) -> Result<HttpServerStarter<C>, GenericError> {
153        ServerBuilder::new(api, private, log.clone())
154            .config(config.clone())
155            .tls(tls)
156            .build_starter()
157            .map_err(|e| Box::new(e) as GenericError)
158    }
159
160    fn new_internal(
161        config: &ConfigDropshot,
162        api: ApiDescription<C>,
163        private: C,
164        log: &Logger,
165        tls: Option<ConfigTls>,
166        version_policy: VersionPolicy,
167    ) -> Result<HttpServerStarter<C>, BuildError> {
168        let tcp = {
169            let std_listener = std::net::TcpListener::bind(
170                &config.bind_address,
171            )
172            .map_err(|e| BuildError::bind_error(e, config.bind_address))?;
173            std_listener.set_nonblocking(true).map_err(|e| {
174                BuildError::generic_system(e, "setting non-blocking")
175            })?;
176            // We use `from_std` instead of just calling `bind` here directly
177            // to avoid invoking an async function.
178            TcpListener::from_std(std_listener).map_err(|e| {
179                BuildError::generic_system(e, "creating TCP listener")
180            })?
181        };
182
183        let local_addr = tcp.local_addr().map_err(|e| {
184            BuildError::generic_system(e, "getting local TCP address")
185        })?;
186
187        let log = log.new(o!("local_addr" => local_addr));
188
189        let server_config = ServerConfig {
190            // We start aggressively to ensure test coverage.
191            default_request_body_max_bytes: config
192                .default_request_body_max_bytes,
193            page_max_nitems: NonZeroU32::new(10000).unwrap(),
194            page_default_nitems: NonZeroU32::new(100).unwrap(),
195            default_handler_task_mode: config.default_handler_task_mode,
196            log_headers: config.log_headers.clone(),
197            compression: config.compression,
198        };
199
200        let tls_acceptor = tls
201            .as_ref()
202            .map(|tls| {
203                Ok(Arc::new(Mutex::new(TlsAcceptor::from(Arc::new(
204                    rustls::ServerConfig::try_from(tls)?,
205                )))))
206            })
207            .transpose()?;
208        let handler_waitgroup = WaitGroup::new();
209
210        let router = api.into_router();
211        if let VersionPolicy::Unversioned = version_policy {
212            if router.has_versioned_routes() {
213                return Err(BuildError::UnversionedServerHasVersionedRoutes);
214            }
215        }
216
217        let app_state = Arc::new(DropshotState {
218            private,
219            config: server_config,
220            router,
221            log: log.clone(),
222            local_addr,
223            tls_acceptor: tls_acceptor.clone(),
224            handler_waitgroup_worker: DebugIgnore(handler_waitgroup.worker()),
225            version_policy,
226        });
227
228        for (path, method, endpoint) in app_state.router.endpoints(None) {
229            debug!(&log, "registered endpoint";
230                "method" => &method,
231                "path" => &path,
232                "versions" => &endpoint.versions,
233            );
234        }
235
236        let http_acceptor = HttpAcceptor { tcp, log: log.clone() };
237
238        Ok(HttpServerStarter {
239            app_state,
240            local_addr,
241            handler_waitgroup,
242            http_acceptor,
243            tls_acceptor,
244        })
245    }
246
247    pub fn start(self) -> HttpServer<C> {
248        let HttpServerStarter {
249            app_state,
250            local_addr,
251            handler_waitgroup,
252            tls_acceptor,
253            http_acceptor,
254        } = self;
255
256        let (tx, mut rx) = tokio::sync::oneshot::channel::<()>();
257        let make_service = ServerConnectionHandler::new(Arc::clone(&app_state));
258        let log = &app_state.log;
259        let log_close = log.clone();
260        let join_handle = tokio::spawn(async move {
261            use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
262            use hyper_util::server::conn::auto;
263
264            let mut builder = auto::Builder::new(TokioExecutor::new());
265            // http/1 settings
266            builder.http1().timer(TokioTimer::new());
267            // http/2 settings
268            builder.http2().timer(TokioTimer::new());
269
270            // Use a graceful watcher to keep track of all existing connections,
271            // and when the close_signal is trigger, force all known conns
272            // to start a graceful shutdown.
273            let graceful =
274                hyper_util::server::graceful::GracefulShutdown::new();
275
276            // The following code looks superficially similar between the HTTP
277            // and HTTPS paths.  However, the concrete types of various objects
278            // are different and so it's not easy to actually share the code.
279            let log = log_close;
280            match tls_acceptor {
281                Some(tls_acceptor) => {
282                    let mut https_acceptor = HttpsAcceptor::new(
283                        log.clone(),
284                        tls_acceptor,
285                        http_acceptor,
286                    );
287                    loop {
288                        tokio::select! {
289                            Some(Ok(sock)) = https_acceptor.accept() => {
290                                let remote_addr = sock.remote_addr();
291                                let handler = make_service
292                                    .make_http_request_handler(remote_addr);
293                                let fut = builder
294                                    .serve_connection_with_upgrades(
295                                        TokioIo::new(sock),
296                                        handler,
297                                    );
298                                let fut = graceful.watch(fut.into_owned());
299                                tokio::spawn(fut);
300                            },
301
302                            _ = &mut rx => {
303                                info!(log, "beginning graceful shutdown");
304                                break;
305                            }
306                        }
307                    }
308                }
309                None => loop {
310                    tokio::select! {
311                        (sock, remote_addr) = http_acceptor.accept() => {
312                            let handler = make_service
313                                .make_http_request_handler(remote_addr);
314                            let fut = builder
315                                .serve_connection_with_upgrades(
316                                    TokioIo::new(sock),
317                                    handler,
318                                );
319                            let fut = graceful.watch(fut.into_owned());
320                            tokio::spawn(fut);
321                        },
322
323                        _ = &mut rx => {
324                            info!(log, "beginning graceful shutdown");
325                            break;
326                        }
327                    }
328                },
329            };
330
331            // optional: could use another select on a timeout
332            graceful.shutdown().await
333        });
334
335        info!(log, "listening");
336
337        let join_handle = async move {
338            // After the server shuts down, we also want to wait for any
339            // detached handler futures to complete.
340            () = join_handle
341                .await
342                .map_err(|e| format!("server stopped: {e}"))?;
343            () = handler_waitgroup.wait().await;
344            Ok(())
345        };
346
347        #[cfg(feature = "usdt-probes")]
348        let probe_registration = match usdt::register_probes() {
349            Ok(_) => {
350                debug!(&log, "successfully registered DTrace USDT probes");
351                ProbeRegistration::Succeeded
352            }
353            Err(e) => {
354                let msg = e.to_string();
355                error!(&log, "failed to register DTrace USDT probes: {}", msg);
356                ProbeRegistration::Failed(msg)
357            }
358        };
359        #[cfg(not(feature = "usdt-probes"))]
360        let probe_registration = {
361            debug!(&log, "DTrace USDT probes compiled out, not registering");
362            ProbeRegistration::Disabled
363        };
364
365        HttpServer {
366            probe_registration,
367            app_state,
368            local_addr,
369            closer: CloseHandle { close_channel: Some(tx) },
370            join_future: join_handle.boxed().shared(),
371        }
372    }
373}
374
375/// Accepts TCP connections like a `TcpListener`, but ignores transient errors
376/// rather than propagating them to the caller
377struct HttpAcceptor {
378    tcp: TcpListener,
379    log: slog::Logger,
380}
381
382impl HttpAcceptor {
383    async fn accept(&self) -> (TcpStream, SocketAddr) {
384        loop {
385            match self.tcp.accept().await {
386                Ok((socket, addr)) => return (socket, addr),
387                Err(e) => match e.kind() {
388                    // These are errors on the individual socket that we
389                    // tried to accept, and so can be ignored.
390                    std::io::ErrorKind::ConnectionRefused
391                    | std::io::ErrorKind::ConnectionAborted
392                    | std::io::ErrorKind::ConnectionReset => (),
393
394                    // This could EMFILE implying resource exhaustion.
395                    // Sleep a little bit and try again.
396                    _ => {
397                        warn!(self.log, "accept error"; "error" => e);
398                        tokio::time::sleep(std::time::Duration::from_millis(
399                            100,
400                        ))
401                        .await;
402                    }
403                },
404            }
405        }
406    }
407}
408
409/// Wrapper for TlsStream<TcpStream> that also carries the remote SocketAddr
410#[derive(Debug)]
411struct TlsConn {
412    stream: TlsStream<TcpStream>,
413    remote_addr: SocketAddr,
414}
415
416impl TlsConn {
417    fn new(stream: TlsStream<TcpStream>, remote_addr: SocketAddr) -> TlsConn {
418        TlsConn { stream, remote_addr }
419    }
420
421    fn remote_addr(&self) -> SocketAddr {
422        self.remote_addr
423    }
424}
425
426/// Forward AsyncRead to the underlying stream
427impl tokio::io::AsyncRead for TlsConn {
428    fn poll_read(
429        mut self: Pin<&mut Self>,
430        ctx: &mut core::task::Context,
431        buf: &mut ReadBuf,
432    ) -> Poll<std::io::Result<()>> {
433        let pinned = Pin::new(&mut self.stream);
434        pinned.poll_read(ctx, buf)
435    }
436}
437
438/// Forward AsyncWrite to the underlying stream
439impl tokio::io::AsyncWrite for TlsConn {
440    fn poll_write(
441        mut self: Pin<&mut Self>,
442        ctx: &mut core::task::Context,
443        data: &[u8],
444    ) -> Poll<std::io::Result<usize>> {
445        let pinned = Pin::new(&mut self.stream);
446        pinned.poll_write(ctx, data)
447    }
448
449    fn poll_flush(
450        mut self: Pin<&mut Self>,
451        ctx: &mut core::task::Context,
452    ) -> Poll<std::io::Result<()>> {
453        let pinned = Pin::new(&mut self.stream);
454        pinned.poll_flush(ctx)
455    }
456
457    fn poll_shutdown(
458        mut self: Pin<&mut Self>,
459        ctx: &mut core::task::Context,
460    ) -> Poll<std::io::Result<()>> {
461        let pinned = Pin::new(&mut self.stream);
462        pinned.poll_shutdown(ctx)
463    }
464}
465
466/// This is our bridge between tokio-rustls and hyper. It implements
467/// `hyper::server::accept::Accept` interface, producing TLS-over-TCP
468/// connections.
469///
470/// Internally, it creates a stream that produces fully negotiated TLS
471/// connections as they come in from a TCP listen socket.  This stream allows
472/// for multiple TLS connections to be negotiated concurrently with new
473/// connections being accepted.
474struct HttpsAcceptor {
475    stream: Box<dyn Stream<Item = std::io::Result<TlsConn>> + Send + Unpin>,
476}
477
478impl HttpsAcceptor {
479    pub fn new(
480        log: slog::Logger,
481        tls_acceptor: Arc<Mutex<TlsAcceptor>>,
482        http_acceptor: HttpAcceptor,
483    ) -> HttpsAcceptor {
484        HttpsAcceptor {
485            stream: Box::new(Box::pin(Self::new_stream(
486                log,
487                tls_acceptor,
488                http_acceptor,
489            ))),
490        }
491    }
492
493    async fn accept(&mut self) -> Option<std::io::Result<TlsConn>> {
494        self.stream.next().await
495    }
496
497    fn new_stream(
498        log: slog::Logger,
499        tls_acceptor: Arc<Mutex<TlsAcceptor>>,
500        http_acceptor: HttpAcceptor,
501    ) -> impl Stream<Item = std::io::Result<TlsConn>> {
502        stream! {
503            let mut tls_negotiations = futures::stream::FuturesUnordered::new();
504            loop {
505                tokio::select! {
506                    Some(negotiation) = tls_negotiations.next(), if
507                            !tls_negotiations.is_empty() => {
508
509                        match negotiation {
510                            Ok(conn) => yield Ok(conn),
511                            Err(e) => {
512                                // If TLS negotiation fails, log the cause but
513                                // don't forward it along. Yielding an error
514                                // from here will terminate the server.
515                                // These failures may be a fatal TLS alert
516                                // message, or a client disconnection during
517                                // negotiation, or other issues.
518                                // TODO: We may want to export a counter for
519                                // different error types, since this may contain
520                                // useful things like "your certificate is
521                                // invalid"
522                                warn!(log, "tls accept err: {}", e);
523                            },
524                        }
525                    },
526                    (socket, addr) = http_acceptor.accept() => {
527                        let tls_negotiation = tls_acceptor
528                            .lock()
529                            .await
530                            .accept(socket)
531                            .map_ok(move |stream| TlsConn::new(stream, addr));
532                        tls_negotiations.push(tls_negotiation);
533                    },
534                    else => break,
535                }
536            }
537        }
538    }
539}
540
541/// Create a TLS configuration from the Dropshot config structure.
542impl TryFrom<&ConfigTls> for rustls::ServerConfig {
543    type Error = BuildError;
544
545    fn try_from(config: &ConfigTls) -> Result<Self, Self::Error> {
546        let (mut cert_reader, mut key_reader): (
547            Box<dyn std::io::BufRead>,
548            Box<dyn std::io::BufRead>,
549        ) = match config {
550            ConfigTls::Dynamic(raw) => {
551                return Ok(raw.clone());
552            }
553            ConfigTls::AsBytes { certs, key } => (
554                Box::new(std::io::BufReader::new(certs.as_slice())),
555                Box::new(std::io::BufReader::new(key.as_slice())),
556            ),
557            ConfigTls::AsFile { cert_file, key_file } => {
558                let certfile = Box::new(std::io::BufReader::new(
559                    std::fs::File::open(cert_file).map_err(|e| {
560                        BuildError::generic_system(
561                            e,
562                            format!("opening {}", cert_file.display()),
563                        )
564                    })?,
565                ));
566                let keyfile = Box::new(std::io::BufReader::new(
567                    std::fs::File::open(key_file).map_err(|e| {
568                        BuildError::generic_system(
569                            e,
570                            format!("opening {}", key_file.display()),
571                        )
572                    })?,
573                ));
574                (certfile, keyfile)
575            }
576        };
577
578        let certs = rustls_pemfile::certs(&mut cert_reader)
579            .collect::<Result<Vec<_>, _>>()
580            .map_err(|err| {
581                BuildError::generic_system(err, "loading TLS certificates")
582            })?;
583        let keys = rustls_pemfile::pkcs8_private_keys(&mut key_reader)
584            .collect::<Result<Vec<_>, _>>()
585            .map_err(|err| {
586                BuildError::generic_system(err, "loading TLS private key")
587            })?;
588        let mut keys_iter = keys.into_iter();
589        let (Some(private_key), None) = (keys_iter.next(), keys_iter.next())
590        else {
591            return Err(BuildError::NotOnePrivateKey);
592        };
593
594        let mut cfg = rustls::ServerConfig::builder()
595            .with_no_client_auth()
596            .with_single_cert(certs, private_key.into())
597            .expect("bad certificate/key");
598        cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
599        Ok(cfg)
600    }
601}
602
603type SharedBoxFuture<T> = Shared<Pin<Box<dyn Future<Output = T> + Send>>>;
604
605/// Future returned by [`HttpServer::wait_for_shutdown()`].
606pub struct ShutdownWaitFuture(SharedBoxFuture<Result<(), String>>);
607
608impl Future for ShutdownWaitFuture {
609    type Output = Result<(), String>;
610
611    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
612        Pin::new(&mut self.get_mut().0).poll(cx)
613    }
614}
615
616impl FusedFuture for ShutdownWaitFuture {
617    fn is_terminated(&self) -> bool {
618        self.0.is_terminated()
619    }
620}
621
622/// A running Dropshot HTTP server.
623///
624/// The generic traits represent the following:
625/// - C: Caller-supplied server context
626pub struct HttpServer<C: ServerContext> {
627    probe_registration: ProbeRegistration,
628    app_state: Arc<DropshotState<C>>,
629    local_addr: SocketAddr,
630    closer: CloseHandle,
631    join_future: SharedBoxFuture<Result<(), String>>,
632}
633
634// Handle used to trigger the shutdown of an [HttpServer].
635struct CloseHandle {
636    close_channel: Option<tokio::sync::oneshot::Sender<()>>,
637}
638
639impl<C: ServerContext> HttpServer<C> {
640    pub fn local_addr(&self) -> SocketAddr {
641        self.local_addr
642    }
643
644    pub fn app_private(&self) -> &C {
645        &self.app_state.private
646    }
647
648    pub fn using_tls(&self) -> bool {
649        self.app_state.using_tls()
650    }
651
652    /// Update TLS certificates for a running HTTPS server.
653    pub async fn refresh_tls(&self, config: &ConfigTls) -> Result<(), String> {
654        let acceptor = &self
655            .app_state
656            .tls_acceptor
657            .as_ref()
658            .ok_or_else(|| "Not configured for TLS".to_string())?;
659
660        *acceptor.lock().await = TlsAcceptor::from(Arc::new(
661            rustls::ServerConfig::try_from(config).unwrap(),
662        ));
663        Ok(())
664    }
665
666    /// Return the result of registering the server's DTrace USDT probes.
667    ///
668    /// See [`ProbeRegistration`] for details.
669    pub fn probe_registration(&self) -> &ProbeRegistration {
670        &self.probe_registration
671    }
672
673    /// Returns a future which completes when the server has shut down.
674    ///
675    /// This function does not cause the server to shut down. It just waits for
676    /// the shutdown to happen.
677    ///
678    /// To trigger a shutdown, Call [HttpServer::close] (which also awaits
679    /// shutdown).
680    pub fn wait_for_shutdown(&self) -> ShutdownWaitFuture {
681        ShutdownWaitFuture(self.join_future.clone())
682    }
683
684    /// Signals the currently running server to stop and waits for it to exit.
685    pub async fn close(mut self) -> Result<(), String> {
686        self.closer
687            .close_channel
688            .take()
689            .expect("cannot close twice")
690            .send(())
691            .expect("failed to send close signal");
692
693        // We _must_ explicitly drop our app state before awaiting join_future.
694        // If we are running handlers in `Detached` mode, our `app_state` has a
695        // `waitgroup::Worker` that they all clone, and `join_future` will await
696        // all of them being dropped. That means we must drop our "primary"
697        // clone of it, too!
698        mem::drop(self.app_state);
699
700        self.join_future.await
701    }
702}
703
704// For graceful termination, the `close()` function is preferred, as it can
705// report errors and wait for termination to complete.  However, we impl
706// `Drop` to attempt to shut down the server to handle less clean shutdowns
707// (e.g., from failing tests).
708impl Drop for CloseHandle {
709    fn drop(&mut self) {
710        if let Some(c) = self.close_channel.take() {
711            // The other side of this channel is owned by a separate tokio task
712            // that's running the hyper server.  We do not expect that to be
713            // cancelled.  But it can happen if the executor itself is shutting
714            // down and that task happens to get cleaned up before this one.
715            let _ = c.send(());
716        }
717    }
718}
719
720impl<C: ServerContext> Future for HttpServer<C> {
721    type Output = Result<(), String>;
722
723    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
724        let server = Pin::into_inner(self);
725        let join_future = Pin::new(&mut server.join_future);
726        join_future.poll(cx)
727    }
728}
729
730impl<C: ServerContext> FusedFuture for HttpServer<C> {
731    fn is_terminated(&self) -> bool {
732        self.join_future.is_terminated()
733    }
734}
735
736/// Initial entry point for handling a new request to the HTTP server.  This is
737/// invoked by Hyper when a new request is received.  This function returns a
738/// Result that either represents a valid HTTP response or an error (which will
739/// also get turned into an HTTP response).
740async fn http_request_handle_wrap<C: ServerContext>(
741    server: Arc<DropshotState<C>>,
742    remote_addr: SocketAddr,
743    request: Request<hyper::body::Incoming>,
744) -> Result<Response<Body>, GenericError> {
745    // This extra level of indirection makes error handling much more
746    // straightforward, since the request handling code can simply return early
747    // with an error and we'll treat it like an error from any of the endpoints
748    // themselves.
749    let start_time = std::time::Instant::now();
750    let request_id = generate_request_id();
751
752    let mut request_log = server.log.new(o!(
753        "remote_addr" => remote_addr,
754        "req_id" => request_id.clone(),
755        "method" => request.method().as_str().to_string(),
756        "uri" => format!("{}", request.uri()),
757    ));
758    // If we have been asked to include any headers from the request in the
759    // log messages, do so here:
760    for name in server.config.log_headers.iter() {
761        let v = request
762            .headers()
763            .get(name)
764            .and_then(|v| v.to_str().ok().map(str::to_string));
765
766        if let Some(v) = v {
767            // This is unfortunate in at least two ways: first, we would like to
768            // just construct _one_ key value map, but OwnedKV is opaque and can
769            // only be constructed with the o!() macro, so the only way to layer
770            // on a dynamic set of additional properties is by creating a chain
771            // of child loggers to add each one; second, we would like to be
772            // able to include all header values under a single map-valued
773            // "header" property, but slog only allows us a single-level
774            // property hierarchy.  Alas!
775            //
776            // We also replace the hyphens with underscores to make it easier to
777            // refer to the generated properties in dynamic languages used for
778            // filtering like rhai.
779            let k = format!("hdr_{}", name.to_lowercase().replace('-', "_"));
780            request_log = request_log.new(o!(k => v));
781        }
782    }
783
784    trace!(request_log, "incoming request");
785    #[cfg(feature = "usdt-probes")]
786    probes::request__start!(|| {
787        let uri = request.uri();
788        crate::dtrace::RequestInfo {
789            id: request_id.clone(),
790            local_addr: server.local_addr,
791            remote_addr,
792            method: request.method().to_string(),
793            path: uri.path().to_string(),
794            query: uri.query().map(|x| x.to_string()),
795        }
796    });
797
798    // Copy local address to report later during the finish probe, as the
799    // server is passed by value to the request handler function.
800    #[cfg(feature = "usdt-probes")]
801    let local_addr = server.local_addr;
802
803    // In the case the client disconnects early, the scopeguard allows us
804    // to perform extra housekeeping before this task is dropped.
805    let on_disconnect = guard((), |_| {
806        let latency_us = start_time.elapsed().as_micros();
807
808        warn!(request_log, "request handling cancelled (client disconnected)";
809            "latency_us" => latency_us,
810        );
811
812        #[cfg(feature = "usdt-probes")]
813        probes::request__done!(|| {
814            crate::dtrace::ResponseInfo {
815                id: request_id.clone(),
816                local_addr,
817                remote_addr,
818                // 499 is a non-standard code popularized by nginx to mean "client disconnected".
819                status_code: 499,
820                message: String::from(
821                    "client disconnected before response returned",
822                ),
823            }
824        });
825    });
826
827    let maybe_response = http_request_handle(
828        server,
829        request,
830        &request_id,
831        request_log.new(o!()),
832        remote_addr,
833    )
834    .await;
835
836    // If `http_request_handle` completed, it means the request wasn't
837    // cancelled and we can safely "defuse" the scopeguard.
838    let _ = ScopeGuard::into_inner(on_disconnect);
839
840    let latency_us = start_time.elapsed().as_micros();
841    let response = match maybe_response {
842        Err(error) => {
843            {
844                let status = error.status_code();
845                let message_external = error.external_message();
846                let message_internal = error.internal_message();
847
848                #[cfg(feature = "usdt-probes")]
849                probes::request__done!(|| {
850                    crate::dtrace::ResponseInfo {
851                        id: request_id.clone(),
852                        local_addr,
853                        remote_addr,
854                        status_code: status.as_u16(),
855                        message: message_external
856                            .cloned()
857                            .unwrap_or_else(|| message_internal.clone()),
858                    }
859                });
860
861                // TODO-debug: add request and response headers here
862                info!(request_log, "request completed";
863                    "response_code" => status.as_u16(),
864                    "latency_us" => latency_us,
865                    "error_message_internal" => message_internal,
866                    "error_message_external" => message_external,
867                );
868            };
869            error.into_response(&request_id)
870        }
871
872        Ok(response) => {
873            // TODO-debug: add request and response headers here
874            info!(request_log, "request completed";
875                "response_code" => response.status().as_u16(),
876                "latency_us" => latency_us,
877            );
878
879            #[cfg(feature = "usdt-probes")]
880            probes::request__done!(|| {
881                crate::dtrace::ResponseInfo {
882                    id: request_id.parse().unwrap(),
883                    local_addr,
884                    remote_addr,
885                    status_code: response.status().as_u16(),
886                    message: "".to_string(),
887                }
888            });
889
890            response
891        }
892    };
893
894    Ok(response)
895}
896
897async fn http_request_handle<C: ServerContext>(
898    server: Arc<DropshotState<C>>,
899    request: Request<hyper::body::Incoming>,
900    request_id: &str,
901    request_log: Logger,
902    remote_addr: std::net::SocketAddr,
903) -> Result<Response<Body>, HandlerError> {
904    // TODO-hardening: is it correct to (and do we correctly) read the entire
905    // request body even if we decide it's too large and are going to send a 400
906    // response?
907    // TODO-hardening: add a request read timeout as well so that we don't allow
908    // this to take forever.
909    // TODO-correctness: Do we need to dump the body on errors?
910    let request = request.map(crate::Body::wrap);
911    let method = request.method().clone();
912    let uri = request.uri();
913    let found_version =
914        server.version_policy.request_version(&request, &request_log)?;
915    let lookup_result = server.router.lookup_route(
916        &method,
917        uri.path().into(),
918        found_version.as_ref(),
919    )?;
920    let rqctx = RequestContext {
921        server: Arc::clone(&server),
922        request: RequestInfo::new(&request, remote_addr),
923        endpoint: lookup_result.endpoint,
924        request_id: request_id.to_string(),
925        log: request_log.clone(),
926    };
927    let request_headers = rqctx.request.headers().clone();
928    let handler = lookup_result.handler;
929
930    let mut response = match server.config.default_handler_task_mode {
931        HandlerTaskMode::CancelOnDisconnect => {
932            // For CancelOnDisconnect, we run the request handler directly: if
933            // the client disconnects, we will be cancelled, and therefore this
934            // future will too.
935            handler.handle_request(rqctx, request).await?
936        }
937        HandlerTaskMode::Detached => {
938            // Spawn the handler so if we're cancelled, the handler still runs
939            // to completion.
940            let (tx, rx) = oneshot::channel();
941            let request_log = request_log.clone();
942            let worker = server.handler_waitgroup_worker.clone();
943            let handler_task = tokio::spawn(async move {
944                let request_log = rqctx.log.clone();
945                let result = handler.handle_request(rqctx, request).await;
946
947                // If this send fails, our spawning task has been cancelled in
948                // the `rx.await` below; log such a result.
949                if let Err(result) = tx.send(result) {
950                    match result {
951                        Ok(r) => warn!(
952                            request_log, "request completed after handler was already cancelled";
953                            "response_code" => r.status().as_u16(),
954                        ),
955                        Err(error) => {
956                            warn!(request_log, "request completed after handler was already cancelled";
957                                "response_code" => error.status_code().as_u16(),
958                                "error_message_internal" => error.internal_message(),
959                                "error_message_external" => error.external_message(),
960                            );
961                        }
962                    }
963                }
964
965                // Drop our waitgroup worker, allowing graceful shutdown to
966                // complete (if it's waiting on us).
967                mem::drop(worker);
968            });
969
970            // The only way we can fail to receive on `rx` is if `tx` is
971            // dropped before a result is sent, which can only happen if
972            // `handle_request` panics. We will propagate such a panic here,
973            // just as we would have in `CancelOnDisconnect` mode above (where
974            // we call the handler directly).
975            match rx.await {
976                Ok(result) => result?,
977                Err(_) => {
978                    error!(request_log, "handler panicked; propagating panic");
979
980                    // To get the panic, we now need to await `handler_task`; we
981                    // know it is complete _and_ it failed, because it has
982                    // dropped `tx` without sending us a result, which is only
983                    // possible if it panicked.
984                    let task_err = handler_task.await.expect_err(
985                        "task failed to send result but didn't panic",
986                    );
987                    panic::resume_unwind(task_err.into_panic());
988                }
989            }
990        }
991    };
992
993    if matches!(server.config.compression, CompressionConfig::Gzip)
994        && is_compressible_content_type(response.headers())
995    {
996        // Add Vary: Accept-Encoding header for all compressible content
997        // types. This needs to be there even if the response ends up not being
998        // compressed because it tells caches (like browsers and CDNs) that the
999        // response content depends on the value of the Accept-Encoding header.
1000        // Without this, a cache might mistakenly serve a compressed response to
1001        // a client that cannot decompress it, or serve an uncompressed response
1002        // to a client that could have benefited from compression.
1003        add_vary_header(response.headers_mut());
1004
1005        if should_compress_response(
1006            &method,
1007            &request_headers,
1008            response.status(),
1009            response.headers(),
1010            response.extensions(),
1011        ) {
1012            response = apply_gzip_compression(response);
1013        }
1014    }
1015
1016    response.headers_mut().insert(
1017        HEADER_REQUEST_ID,
1018        http::header::HeaderValue::from_str(&request_id).unwrap(),
1019    );
1020    Ok(response)
1021}
1022
1023// This function should probably be parametrized by some name of the service
1024// that is expected to be unique within an organization.  That way, it would be
1025// possible to determine from a given request id which service it was from.
1026// TODO should we encode more information here?  Service?  Instance?  Time up to
1027// the hour?
1028fn generate_request_id() -> String {
1029    format!("{}", Uuid::new_v4())
1030}
1031
1032/// ServerConnectionHandler is a Hyper Service implementation that forwards
1033/// incoming connections to `http_connection_handle()`, providing the server
1034/// state object as an additional argument.  We could use `make_service_fn` here
1035/// using a closure to capture the state object, but the resulting code is a bit
1036/// simpler without it.
1037pub struct ServerConnectionHandler<C: ServerContext> {
1038    /// backend state that will be made available to the connection handler
1039    server: Arc<DropshotState<C>>,
1040}
1041
1042impl<C: ServerContext> ServerConnectionHandler<C> {
1043    /// Create an ServerConnectionHandler with the given state object that
1044    /// will be made available to the handler.
1045    fn new(server: Arc<DropshotState<C>>) -> Self {
1046        ServerConnectionHandler { server }
1047    }
1048
1049    /// Initial entry point for handling a new connection to the HTTP server.
1050    /// This is invoked by Hyper when a new connection is accepted.  This function
1051    /// must return a Hyper Service object that will handle requests for this
1052    /// connection.
1053    fn make_http_request_handler(
1054        &self,
1055        remote_addr: SocketAddr,
1056    ) -> ServerRequestHandler<C> {
1057        info!(self.server.log, "accepted connection"; "remote_addr" => %remote_addr);
1058        ServerRequestHandler::new(self.server.clone(), remote_addr)
1059    }
1060}
1061
1062/// ServerRequestHandler is a Hyper Service implementation that forwards
1063/// incoming requests to `http_request_handle_wrap()`, including as an argument
1064/// the backend server state object.  We could use `service_fn` here using a
1065/// closure to capture the server state object, but the resulting code is a bit
1066/// simpler without all that.
1067pub struct ServerRequestHandler<C: ServerContext> {
1068    /// backend state that will be made available to the request handler
1069    server: Arc<DropshotState<C>>,
1070    remote_addr: SocketAddr,
1071}
1072
1073impl<C: ServerContext> ServerRequestHandler<C> {
1074    /// Create a ServerRequestHandler object with the given state object that
1075    /// will be provided to the handler function.
1076    fn new(server: Arc<DropshotState<C>>, remote_addr: SocketAddr) -> Self {
1077        ServerRequestHandler { server, remote_addr }
1078    }
1079}
1080
1081impl<C: ServerContext> Service<Request<hyper::body::Incoming>>
1082    for ServerRequestHandler<C>
1083{
1084    type Response = Response<Body>;
1085    type Error = GenericError;
1086    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
1087
1088    fn call(&self, req: Request<hyper::body::Incoming>) -> Self::Future {
1089        Box::pin(http_request_handle_wrap(
1090            Arc::clone(&self.server),
1091            self.remote_addr,
1092            req,
1093        ))
1094    }
1095}
1096
1097/// Errors encountered while configuring a Dropshot server
1098#[derive(Debug, Error)]
1099pub enum BuildError {
1100    #[error("failed to bind to {address}")]
1101    BindError {
1102        address: SocketAddr,
1103        #[source]
1104        error: std::io::Error,
1105    },
1106    #[error("expected exactly one TLS private key")]
1107    NotOnePrivateKey,
1108    #[error("{context}")]
1109    SystemError {
1110        context: String,
1111        #[source]
1112        error: std::io::Error,
1113    },
1114    #[error(
1115        "unversioned servers cannot have endpoints with specific versions"
1116    )]
1117    UnversionedServerHasVersionedRoutes,
1118}
1119
1120impl BuildError {
1121    /// Generate an error for failure to bind to `address`
1122    fn bind_error(error: std::io::Error, address: SocketAddr) -> BuildError {
1123        BuildError::BindError { address, error }
1124    }
1125
1126    /// Generate an error for any kind of `std::io::Error`
1127    ///
1128    /// `context` describes more about what we were trying to do that generated
1129    /// the error.
1130    fn generic_system<S: Into<String>>(
1131        error: std::io::Error,
1132        context: S,
1133    ) -> BuildError {
1134        BuildError::SystemError { context: context.into(), error }
1135    }
1136}
1137
1138/// Start configuring a Dropshot server
1139#[derive(Debug)]
1140pub struct ServerBuilder<C: ServerContext> {
1141    // required caller-provided values
1142    private: C,
1143    log: Logger,
1144    api: DebugIgnore<ApiDescription<C>>,
1145
1146    // optional caller-provided values
1147    config: ConfigDropshot,
1148    version_policy: VersionPolicy,
1149    tls: Option<ConfigTls>,
1150}
1151
1152impl<C: ServerContext> ServerBuilder<C> {
1153    /// Start configuring a new Dropshot server
1154    ///
1155    /// * `api`: the API to be hosted on this server
1156    /// * `private`: your private data that will be made available in
1157    ///   `RequestContext`
1158    /// * `log`: a slog logger for all server events
1159    pub fn new(
1160        api: ApiDescription<C>,
1161        private: C,
1162        log: Logger,
1163    ) -> ServerBuilder<C> {
1164        ServerBuilder {
1165            private,
1166            log,
1167            api: DebugIgnore(api),
1168            config: Default::default(),
1169            version_policy: VersionPolicy::Unversioned,
1170            tls: Default::default(),
1171        }
1172    }
1173
1174    /// Specify the server configuration
1175    pub fn config(mut self, config: ConfigDropshot) -> Self {
1176        self.config = config;
1177        self
1178    }
1179
1180    /// Specify the TLS configuration, if any
1181    ///
1182    /// `None` (the default) means no TLS.  The server will listen for plain
1183    /// HTTP.
1184    pub fn tls(mut self, tls: Option<ConfigTls>) -> Self {
1185        self.tls = tls;
1186        self
1187    }
1188
1189    /// Specifies whether and how this server determines the API version to use
1190    /// for incoming requests
1191    ///
1192    /// All the interfaces related to [`VersionPolicy`] are considered
1193    /// experimental and may change in an upcoming release.
1194    pub fn version_policy(mut self, version_policy: VersionPolicy) -> Self {
1195        self.version_policy = version_policy;
1196        self
1197    }
1198
1199    /// Start the server
1200    ///
1201    /// # Errors
1202    ///
1203    /// See [`ServerBuilder::build_starter()`].
1204    pub fn start(self) -> Result<HttpServer<C>, BuildError> {
1205        Ok(self.build_starter()?.start())
1206    }
1207
1208    /// Build an `HttpServerStarter` that can be used to start the server
1209    ///
1210    /// Most consumers probably want to use `start()` instead.
1211    ///
1212    /// # Errors
1213    ///
1214    /// This fails if:
1215    ///
1216    /// * We could not bind to the requested IP address and TCP port
1217    /// * The provided `tls` configuration was not valid
1218    /// * The `version_policy` is `VersionPolicy::Unversioned` and `api` (the
1219    ///   `ApiDescription`) contains any endpoints that are version-restricted
1220    ///   (i.e., have "versions" set to anything other than
1221    ///   `ApiEndpointVersions::All)`.  Versioned routes are not supported with
1222    ///   unversioned servers.
1223    pub fn build_starter(self) -> Result<HttpServerStarter<C>, BuildError> {
1224        HttpServerStarter::new_internal(
1225            &self.config,
1226            self.api.0,
1227            self.private,
1228            &self.log,
1229            self.tls,
1230            self.version_policy,
1231        )
1232    }
1233}
1234
#[cfg(test)]
mod test {
    use super::*;
    // Referring to the current crate as "dropshot::" instead of "crate::"
    // helps the endpoint macro with module lookup.
    use crate as dropshot;
    use dropshot::endpoint;
    use dropshot::test_util::ClientTestContext;
    use dropshot::test_util::LogContext;
    use dropshot::ConfigLogging;
    use dropshot::ConfigLoggingLevel;
    use dropshot::HttpError;
    use dropshot::HttpResponseOk;
    use dropshot::RequestContext;
    use http::StatusCode;
    use hyper::Method;

    use futures::future::FusedFuture;

    // Minimal endpoint used by the tests below: always answers 200 with `3`.
    #[endpoint {
        method = GET,
        path = "/handler",
    }]
    async fn handler(
        _rqctx: RequestContext<i32>,
    ) -> Result<HttpResponseOk<u64>, HttpError> {
        Ok(HttpResponseOk(3))
    }

    /// Keeps the log context alive for as long as the test server is in use.
    struct TestConfig {
        log_context: LogContext,
    }

    impl TestConfig {
        /// The logger belonging to the test's log context.
        fn log(&self) -> &slog::Logger {
            &self.log_context.log
        }
    }

    /// Starts a server with the default configuration, the single `handler`
    /// endpoint registered, and a warn-level stderr logger.
    fn create_test_server() -> (HttpServer<i32>, TestConfig) {
        let config_dropshot = ConfigDropshot::default();

        let mut api = ApiDescription::new();
        api.register(handler).unwrap();

        let log_context = LogContext::new(
            "test server",
            &ConfigLogging::StderrTerminal { level: ConfigLoggingLevel::Warn },
        );

        let starter =
            HttpServerStarter::new(&config_dropshot, api, 0, &log_context.log)
                .unwrap();
        let server = starter.start();

        (server, TestConfig { log_context })
    }

    /// Issues one `GET /handler` request from a separate task and checks that
    /// it succeeds with a 200 status.
    async fn single_client_request(addr: SocketAddr, log: &slog::Logger) {
        let client_log = log.new(o!("http_client" => "dropshot test suite"));
        let client_testctx = ClientTestContext::new(addr, client_log);
        let request_task = tokio::task::spawn(async move {
            let outcome = client_testctx
                .make_request(
                    Method::GET,
                    "/handler",
                    None as Option<()>,
                    StatusCode::OK,
                )
                .await;

            assert!(outcome.is_ok());
        });

        request_task.await.expect("client request failed");
    }

    #[tokio::test]
    async fn test_server_run_then_close() {
        let (mut server, test_config) = create_test_server();
        let client =
            single_client_request(server.local_addr(), test_config.log());

        // The client must finish first; the server is expected to keep
        // running until it is explicitly closed.
        futures::select! {
            _ = client.fuse() => {},
            r = server => panic!("Server unexpectedly terminated: {:?}", r),
        }

        assert!(!server.is_terminated());
        assert!(server.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_drop_server_without_close_okay() {
        let (server, _) = create_test_server();
        std::mem::drop(server);
    }

    #[tokio::test]
    async fn test_http_acceptor_happy_path() {
        const TOTAL: usize = 100;
        let tcp =
            tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let addr = tcp.local_addr().expect("local_addr");
        let log = slog::Logger::root(slog::Discard, o!());
        let acceptor = HttpAcceptor { log, tcp };

        // One task accepts TOTAL connections while the other makes them.
        let accept_task = tokio::spawn(async move {
            for _ in 0..TOTAL {
                let _ = acceptor.accept().await;
            }
        });

        let connect_task = tokio::spawn(async move {
            for _ in 0..TOTAL {
                tokio::net::TcpStream::connect(&addr).await.expect("connect");
            }
        });

        accept_task.await.expect("task 1");
        connect_task.await.expect("task 2");
    }
}