tonic_rustls/server/mod.rs

//! Server implementation and builder.

mod conn;
mod incoming;
mod service;
#[cfg(unix)]
mod unix;

use tokio_stream::StreamExt as _;
use tracing::{debug, trace};

use tonic::service::Routes;

pub use conn::{Connected, TcpConnectInfo};
use hyper_util::{
    rt::{TokioExecutor, TokioIo, TokioTimer},
    server::conn::auto::{Builder as ConnectionBuilder, HttpServerConnExec},
    service::TowerToHyperService,
};

#[cfg(feature = "tls")]
pub use conn::TlsConnectInfo;

#[cfg(feature = "tls")]
use self::service::TlsAcceptor;

#[cfg(unix)]
pub use unix::UdsConnectInfo;

pub use incoming::TcpIncoming;

#[cfg(feature = "tls")]
use crate::Error;

use self::service::{RecoverError, ServerIo};
use super::service::GrpcTimeout;
use bytes::Bytes;
use http::{Request, Response};
use http_body_util::BodyExt;
use hyper::{body::Incoming, service::Service as HyperService};
use pin_project::pin_project;
use std::future::pending;
use std::{
    convert::Infallible,
    fmt,
    future::{self, poll_fn, Future},
    marker::PhantomData,
    net::SocketAddr,
    pin::{pin, Pin},
    sync::Arc,
    task::{ready, Context, Poll},
    time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::sleep;
use tokio_stream::Stream;
use tonic::body::Body;
use tonic::server::NamedService;
use tower::{
    layer::util::{Identity, Stack},
    layer::Layer,
    limit::concurrency::ConcurrencyLimitLayer,
    util::BoxCloneService,
    Service, ServiceBuilder, ServiceExt,
};

type BoxService = tower::util::BoxCloneService<Request<Body>, Response<Body>, crate::BoxError>;
type TraceInterceptor = Arc<dyn Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static>;

const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20;

/// A default batteries-included `transport` server.
///
/// This provides a builder-pattern [`Server`] on top of `hyper` connections,
/// exposing configuration parameters for a fully featured HTTP/2 based gRPC
/// server. It aims to be a good out-of-the-box HTTP/2 server for use with
/// tonic, while also serving as a reference implementation and a starting
/// point for anyone wanting to create a more complex and/or specific
/// implementation.
#[derive(Clone)]
pub struct Server<L = Identity> {
    trace_interceptor: Option<TraceInterceptor>,
    concurrency_limit: Option<usize>,
    timeout: Option<Duration>,
    #[cfg(feature = "tls")]
    tls: Option<TlsAcceptor>,
    init_stream_window_size: Option<u32>,
    init_connection_window_size: Option<u32>,
    max_concurrent_streams: Option<u32>,
    tcp_keepalive: Option<Duration>,
    tcp_nodelay: bool,
    http2_keepalive_interval: Option<Duration>,
    http2_keepalive_timeout: Option<Duration>,
    http2_adaptive_window: Option<bool>,
    http2_max_pending_accept_reset_streams: Option<usize>,
    http2_max_header_list_size: Option<u32>,
    max_frame_size: Option<u32>,
    accept_http1: bool,
    service_builder: ServiceBuilder<L>,
    max_connection_age: Option<Duration>,
}

impl Default for Server<Identity> {
    fn default() -> Self {
        Self {
            trace_interceptor: None,
            concurrency_limit: None,
            timeout: None,
            #[cfg(feature = "tls")]
            tls: None,
            init_stream_window_size: None,
            init_connection_window_size: None,
            max_concurrent_streams: None,
            tcp_keepalive: None,
            tcp_nodelay: false,
            http2_keepalive_interval: None,
            http2_keepalive_timeout: None,
            http2_adaptive_window: None,
            http2_max_pending_accept_reset_streams: None,
            http2_max_header_list_size: None,
            max_frame_size: None,
            accept_http1: false,
            service_builder: Default::default(),
            max_connection_age: None,
        }
    }
}

/// A stack based [`Service`] router.
#[derive(Debug)]
pub struct Router<L = Identity> {
    server: Server<L>,
    routes: Routes,
}

impl Server {
    /// Create a new server builder that can configure a [`Server`].
    pub fn builder() -> Self {
        Server {
            tcp_nodelay: true,
            accept_http1: false,
            ..Default::default()
        }
    }
}

impl<L> Server<L> {
    /// Configure TLS for this server.
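    ///
    /// # Example
    ///
    /// A sketch only; `load_rustls_config()` is a placeholder for however you
    /// build your `rustls::ServerConfig` (certificate and key loading is
    /// omitted here):
    ///
    /// ```ignore
    /// # use tonic_rustls::Server;
    /// let tls = load_rustls_config(); // hypothetical helper returning a rustls ServerConfig
    /// let builder = Server::builder().tls_config(tls).unwrap();
    /// ```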
    #[cfg(feature = "tls")]
    pub fn tls_config(self, tls_config: tokio_rustls::rustls::ServerConfig) -> Result<Self, Error> {
        let tls_acceptor = TlsAcceptor::new(tls_config).map_err(Error::from_source)?;
        Ok(Server {
            tls: Some(tls_acceptor),
            ..self
        })
    }

    /// Set the concurrency limit applied to inbound requests per connection.
    ///
    /// # Example
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # use tower_service::Service;
    /// # let builder = Server::builder();
    /// builder.concurrency_limit_per_connection(32);
    /// ```
    #[must_use]
    pub fn concurrency_limit_per_connection(self, limit: usize) -> Self {
        Server {
            concurrency_limit: Some(limit),
            ..self
        }
    }

    /// Set a timeout for all request handlers.
    ///
    /// # Example
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # use tower_service::Service;
    /// # use std::time::Duration;
    /// # let builder = Server::builder();
    /// builder.timeout(Duration::from_secs(30));
    /// ```
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Server {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2
    /// stream-level flow control.
    ///
    /// Default is 65,535
    ///
    /// [spec]: https://httpwg.org/specs/rfc9113.html#InitialWindowSize
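    ///
    /// # Example
    ///
    /// A usage sketch; the window size shown is illustrative:
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # let builder = Server::builder();
    /// builder.initial_stream_window_size(1024 * 1024u32); // 1 MiB
    /// ```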
    #[must_use]
    pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
        Server {
            init_stream_window_size: sz.into(),
            ..self
        }
    }

    /// Sets the max connection-level flow control for HTTP2
    ///
    /// Default is 65,535
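    ///
    /// # Example
    ///
    /// A usage sketch; the window size shown is illustrative:
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # let builder = Server::builder();
    /// builder.initial_connection_window_size(2 * 1024 * 1024u32); // 2 MiB
    /// ```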
    #[must_use]
    pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
        Server {
            init_connection_window_size: sz.into(),
            ..self
        }
    }

    /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2
    /// connections.
    ///
    /// Default is no limit (`None`).
    ///
    /// [spec]: https://httpwg.org/specs/rfc9113.html#n-stream-concurrency
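    ///
    /// # Example
    ///
    /// A usage sketch; the limit shown is illustrative:
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # let builder = Server::builder();
    /// builder.max_concurrent_streams(128u32);
    /// ```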
    #[must_use]
    pub fn max_concurrent_streams(self, max: impl Into<Option<u32>>) -> Self {
        Server {
            max_concurrent_streams: max.into(),
            ..self
        }
    }

    /// Sets the maximum duration a connection may exist before a graceful
    /// shutdown is initiated.
    ///
    /// Default is no limit (`None`).
    ///
    /// # Example
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # use tower_service::Service;
    /// # use std::time::Duration;
    /// # let builder = Server::builder();
    /// builder.max_connection_age(Duration::from_secs(60));
    /// ```
    #[must_use]
    pub fn max_connection_age(self, max_connection_age: Duration) -> Self {
        Server {
            max_connection_age: Some(max_connection_age),
            ..self
        }
    }

    /// Set whether HTTP2 Ping frames are enabled on accepted connections.
    ///
    /// If `None` is specified, HTTP2 keepalive is disabled, otherwise the duration
    /// specified will be the time interval between HTTP2 Ping frames.
    /// The timeout for receiving an acknowledgement of the keepalive ping
    /// can be set with [`Server::http2_keepalive_timeout`].
    ///
    /// Default is no HTTP2 keepalive (`None`)
    ///
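    /// # Example
    ///
    /// A sketch pairing the ping interval with [`Server::http2_keepalive_timeout`];
    /// the durations shown are illustrative:
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # use std::time::Duration;
    /// # let builder = Server::builder();
    /// builder
    ///     .http2_keepalive_interval(Some(Duration::from_secs(30)))
    ///     .http2_keepalive_timeout(Some(Duration::from_secs(10)));
    /// ```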
    #[must_use]
    pub fn http2_keepalive_interval(self, http2_keepalive_interval: Option<Duration>) -> Self {
        Server {
            http2_keepalive_interval,
            ..self
        }
    }

    /// Sets a timeout for receiving an acknowledgement of the keepalive ping.
    ///
    /// If the ping is not acknowledged within the timeout, the connection will be closed.
    /// Does nothing if [`Server::http2_keepalive_interval`] is disabled.
    ///
    /// Default is 20 seconds.
    ///
    #[must_use]
    pub fn http2_keepalive_timeout(self, http2_keepalive_timeout: Option<Duration>) -> Self {
        Server {
            http2_keepalive_timeout,
            ..self
        }
    }

    /// Sets whether to use adaptive flow control. Defaults to `false`.
    /// Enabling this will override the limits set in
    /// [`Server::initial_stream_window_size`] and
    /// [`Server::initial_connection_window_size`].
    #[must_use]
    pub fn http2_adaptive_window(self, enabled: Option<bool>) -> Self {
        Server {
            http2_adaptive_window: enabled,
            ..self
        }
    }

    /// Configures the maximum number of pending reset streams allowed before a GOAWAY will be sent.
    ///
    /// This will default to whatever the default in h2 is. As of v0.3.17, it is 20.
    ///
    /// See <https://github.com/hyperium/hyper/issues/2877> for more information.
    #[must_use]
    pub fn http2_max_pending_accept_reset_streams(self, max: Option<usize>) -> Self {
        Server {
            http2_max_pending_accept_reset_streams: max,
            ..self
        }
    }

    /// Set whether TCP keepalive messages are enabled on accepted connections.
    ///
    /// If `None` is specified, keepalive is disabled, otherwise the duration
    /// specified will be the time to remain idle before sending TCP keepalive
    /// probes.
    ///
    /// Default is no keepalive (`None`)
    ///
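    /// # Example
    ///
    /// A usage sketch; the idle duration shown is illustrative:
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # use std::time::Duration;
    /// # let builder = Server::builder();
    /// builder.tcp_keepalive(Some(Duration::from_secs(60)));
    /// ```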
    #[must_use]
    pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
        Server {
            tcp_keepalive,
            ..self
        }
    }

    /// Set the value of `TCP_NODELAY` option for accepted connections. Enabled by default.
    #[must_use]
    pub fn tcp_nodelay(self, enabled: bool) -> Self {
        Server {
            tcp_nodelay: enabled,
            ..self
        }
    }

    /// Sets the max size of received header frames.
    ///
    /// This will default to whatever the default in hyper is. As of v1.4.1, it is 16 KiB.
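    ///
    /// # Example
    ///
    /// A usage sketch; the size shown is illustrative:
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # let builder = Server::builder();
    /// builder.http2_max_header_list_size(16 * 1024u32); // 16 KiB
    /// ```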
    #[must_use]
    pub fn http2_max_header_list_size(self, max: impl Into<Option<u32>>) -> Self {
        Server {
            http2_max_header_list_size: max.into(),
            ..self
        }
    }

    /// Sets the maximum frame size to use for HTTP2.
    ///
    /// Passing `None` will do nothing.
    ///
    /// If not set, the underlying transport's default is used.
    #[must_use]
    pub fn max_frame_size(self, frame_size: impl Into<Option<u32>>) -> Self {
        Server {
            max_frame_size: frame_size.into(),
            ..self
        }
    }

    /// Allow this server to accept http1 requests.
    ///
    /// Accepting http1 requests is only useful when developing `grpc-web`
    /// enabled services. If this setting is set to `true` but services are
    /// not correctly configured to handle grpc-web requests, your server may
    /// return confusing (but correct) protocol errors.
    ///
    /// Default is `false`.
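    ///
    /// # Example
    ///
    /// A usage sketch for enabling http1 (e.g. for `grpc-web`):
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # let builder = Server::builder();
    /// builder.accept_http1(true);
    /// ```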
    #[must_use]
    pub fn accept_http1(self, accept_http1: bool) -> Self {
        Server {
            accept_http1,
            ..self
        }
    }

    /// Intercept inbound headers and add a [`tracing::Span`] to each response future.
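    ///
    /// # Example
    ///
    /// A sketch that records the request URI on a span; the span name and
    /// field are illustrative:
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # let builder = Server::builder();
    /// builder.trace_fn(|request| tracing::info_span!("grpc", uri = %request.uri()));
    /// ```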
    #[must_use]
    pub fn trace_fn<F>(self, f: F) -> Self
    where
        F: Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static,
    {
        Server {
            trace_interceptor: Some(Arc::new(f)),
            ..self
        }
    }

    /// Create a router with the `S` typed service as the first service.
    ///
    /// This will clone the `Server` builder and create a router that will
    /// route between different services.
    pub fn add_service<S>(&mut self, svc: S) -> Router<L>
    where
        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Future: Send + 'static,
        L: Clone,
    {
        Router::new(self.clone(), Routes::new(svc))
    }

    /// Create a router with the optional `S` typed service as the first service.
    ///
    /// This will clone the `Server` builder and create a router that will
    /// route between different services.
    ///
    /// # Note
    /// Even when the argument given is `None` this will capture *all* requests to this service name.
    /// As a result, one cannot use this to toggle between two identically named implementations.
    pub fn add_optional_service<S>(&mut self, svc: Option<S>) -> Router<L>
    where
        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Future: Send + 'static,
        L: Clone,
    {
        let routes = svc.map(Routes::new).unwrap_or_default();
        Router::new(self.clone(), routes)
    }

    /// Create a router with given [`Routes`].
    ///
    /// This will clone the `Server` builder and create a router that will
    /// route between the services already added to the provided `routes`.
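    ///
    /// # Example
    ///
    /// A minimal sketch using an empty [`Routes`] value; real code would add
    /// generated gRPC services to `routes` first:
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// use tonic::service::Routes;
    ///
    /// let router = Server::builder().add_routes(Routes::default());
    /// ```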
    pub fn add_routes(&mut self, routes: Routes) -> Router<L>
    where
        L: Clone,
    {
        Router::new(self.clone(), routes)
    }

    /// Set the [Tower] [`Layer`] all services will be wrapped in.
    ///
    /// This enables using middleware from the [Tower ecosystem][eco].
    ///
    /// # Example
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # use tower_service::Service;
    /// use tower::timeout::TimeoutLayer;
    /// use std::time::Duration;
    ///
    /// # let mut builder = Server::builder();
    /// builder.layer(TimeoutLayer::new(Duration::from_secs(30)));
    /// ```
    ///
    /// Note that timeouts should be set using [`Server::timeout`]. `TimeoutLayer` is only used
    /// here as an example.
    ///
    /// You can build more complex layers using [`ServiceBuilder`]. Those layers can include
    /// [interceptors]:
    ///
    /// ```
    /// # use tonic_rustls::Server;
    /// # use tower_service::Service;
    /// use tower::ServiceBuilder;
    /// use std::time::Duration;
    /// use tonic::{Request, Status, service::interceptor};
    ///
    /// fn auth_interceptor(request: Request<()>) -> Result<Request<()>, Status> {
    ///     if valid_credentials(&request) {
    ///         Ok(request)
    ///     } else {
    ///         Err(Status::unauthenticated("invalid credentials"))
    ///     }
    /// }
    ///
    /// fn valid_credentials(request: &Request<()>) -> bool {
    ///     // ...
    ///     # true
    /// }
    ///
    /// fn some_other_interceptor(request: Request<()>) -> Result<Request<()>, Status> {
    ///     Ok(request)
    /// }
    ///
    /// let layer = ServiceBuilder::new()
    ///     .load_shed()
    ///     .timeout(Duration::from_secs(30))
    ///     .layer(auth_interceptor)
    ///     .layer(some_other_interceptor)
    ///     .into_inner();
    ///
    /// Server::builder().layer(layer);
    /// ```
    ///
    /// [Tower]: https://github.com/tower-rs/tower
    /// [`Layer`]: tower::layer::Layer
    /// [eco]: https://github.com/tower-rs
    /// [`ServiceBuilder`]: tower::ServiceBuilder
    /// [interceptors]: crate::service::Interceptor
    pub fn layer<NewLayer>(self, new_layer: NewLayer) -> Server<Stack<NewLayer, L>> {
        Server {
            service_builder: self.service_builder.layer(new_layer),
            trace_interceptor: self.trace_interceptor,
            concurrency_limit: self.concurrency_limit,
            timeout: self.timeout,
            #[cfg(feature = "tls")]
            tls: self.tls,
            init_stream_window_size: self.init_stream_window_size,
            init_connection_window_size: self.init_connection_window_size,
            max_concurrent_streams: self.max_concurrent_streams,
            tcp_keepalive: self.tcp_keepalive,
            tcp_nodelay: self.tcp_nodelay,
            http2_keepalive_interval: self.http2_keepalive_interval,
            http2_keepalive_timeout: self.http2_keepalive_timeout,
            http2_adaptive_window: self.http2_adaptive_window,
            http2_max_pending_accept_reset_streams: self.http2_max_pending_accept_reset_streams,
            http2_max_header_list_size: self.http2_max_header_list_size,
            max_frame_size: self.max_frame_size,
            accept_http1: self.accept_http1,
            max_connection_age: self.max_connection_age,
        }
    }

    pub(crate) async fn serve_with_shutdown<S, I, F, IO, IE, ResBody>(
        self,
        svc: S,
        incoming: I,
        signal: Option<F>,
    ) -> Result<(), super::Error>
    where
        L: Layer<S>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send + 'static,
        <<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
            Into<crate::BoxError> + Send + 'static,
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IO::ConnectInfo: Clone + Send + Sync + 'static,
        IE: Into<crate::BoxError>,
        F: Future<Output = ()>,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let trace_interceptor = self.trace_interceptor.clone();
        let concurrency_limit = self.concurrency_limit;
        let init_connection_window_size = self.init_connection_window_size;
        let init_stream_window_size = self.init_stream_window_size;
        let max_concurrent_streams = self.max_concurrent_streams;
        let timeout = self.timeout;
        let max_header_list_size = self.http2_max_header_list_size;
        let max_frame_size = self.max_frame_size;
        let http2_only = !self.accept_http1;

        let http2_keepalive_interval = self.http2_keepalive_interval;
        let http2_keepalive_timeout = self
            .http2_keepalive_timeout
            .unwrap_or_else(|| Duration::new(DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS, 0));
        let http2_adaptive_window = self.http2_adaptive_window;
        let http2_max_pending_accept_reset_streams = self.http2_max_pending_accept_reset_streams;
        let max_connection_age = self.max_connection_age;

        let svc = self.service_builder.service(svc);

        // Wrap the raw accept stream, applying the TLS acceptor when one is configured.
        let incoming = incoming::tcp_incoming(
            incoming,
            #[cfg(feature = "tls")]
            self.tls,
        );
        let mut svc = MakeSvc {
            inner: svc,
            concurrency_limit,
            timeout,
            trace_interceptor,
            _io: PhantomData,
        };

        let server = {
            let mut builder = ConnectionBuilder::new(TokioExecutor::new());

            if http2_only {
                builder = builder.http2_only();
            }

            builder
                .http2()
                .timer(TokioTimer::new())
                .initial_connection_window_size(init_connection_window_size)
                .initial_stream_window_size(init_stream_window_size)
                .max_concurrent_streams(max_concurrent_streams)
                .keep_alive_interval(http2_keepalive_interval)
                .keep_alive_timeout(http2_keepalive_timeout)
                .adaptive_window(http2_adaptive_window.unwrap_or_default())
                .max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams)
                .max_frame_size(max_frame_size);

            if let Some(max_header_list_size) = max_header_list_size {
                builder.http2().max_header_list_size(max_header_list_size);
            }

            builder
        };

        // Graceful shutdown bookkeeping: every spawned connection task holds a clone of
        // `signal_rx`, and `signal_tx.closed()` resolves once all of them have been dropped.
        let (signal_tx, signal_rx) = tokio::sync::watch::channel(());
        let signal_tx = Arc::new(signal_tx);

        let graceful = signal.is_some();
        let mut sig = pin!(Fuse { inner: signal });
        let mut incoming = pin!(incoming);

        // Accept connections until the shutdown signal fires or the incoming stream ends.
        loop {
            tokio::select! {
                _ = &mut sig => {
                    trace!("signal received, shutting down");
                    break;
                },
                io = incoming.next() => {
                    let io = match io {
                        Some(Ok(io)) => io,
                        Some(Err(e)) => {
                            trace!("error accepting connection: {:#}", e);
                            continue;
                        },
                        None => {
                            break
                        },
                    };

                    trace!("connection accepted");

                    poll_fn(|cx| svc.poll_ready(cx))
                        .await
                        .map_err(super::Error::from_source)?;

                    let req_svc = svc
                        .call(&io)
                        .await
                        .map_err(super::Error::from_source)?;

                    let hyper_io = TokioIo::new(io);
                    let hyper_svc = TowerToHyperService::new(req_svc.map_request(|req: Request<Incoming>| req.map(Body::new)));

                    serve_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone()), max_connection_age);
                }
            }
        }

        if graceful {
            let _ = signal_tx.send(());
            drop(signal_rx);
            trace!(
                "waiting for {} connections to close",
                signal_tx.receiver_count()
            );

            // Wait for all connections to close
            signal_tx.closed().await;
        }

        Ok(())
    }
}

// This is moved to its own function as a way to get around
// https://github.com/rust-lang/rust/issues/102211
fn serve_connection<B, IO, S, E>(
    hyper_io: IO,
    hyper_svc: S,
    builder: ConnectionBuilder<E>,
    mut watcher: Option<tokio::sync::watch::Receiver<()>>,
    max_connection_age: Option<Duration>,
) where
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
    IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    S: HyperService<Request<Incoming>, Response = Response<B>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
    E: HttpServerConnExec<S::Future, B> + Send + Sync + 'static,
{
    tokio::spawn(async move {
        {
            let mut sig = pin!(Fuse {
                inner: watcher.as_mut().map(|w| w.changed()),
            });

            let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc));

            let sleep = sleep_or_pending(max_connection_age);
            tokio::pin!(sleep);

            loop {
                tokio::select! {
                    rv = &mut conn => {
                        if let Err(err) = rv {
                            debug!("failed serving connection: {:#}", err);
                        }
                        break;
                    },
                    _ = &mut sleep => {
                        // Max connection age reached: start a graceful shutdown and park the
                        // timer so it does not fire again while the connection drains.
                        conn.as_mut().graceful_shutdown();
                        sleep.set(sleep_or_pending(None));
                    },
                    _ = &mut sig => {
                        // Server-wide shutdown signal: stop accepting new streams and drain.
                        conn.as_mut().graceful_shutdown();
                    }
                }
            }
        }

        // Dropping the watcher lets the server's graceful shutdown complete once
        // every connection task has finished.
        drop(watcher);
        trace!("connection closed");
    });
}

async fn sleep_or_pending(wait_for: Option<Duration>) {
    match wait_for {
        Some(wait) => sleep(wait).await,
        None => pending().await,
    };
}

impl<L> Router<L> {
    pub(crate) fn new(server: Server<L>, routes: Routes) -> Self {
        Self { server, routes }
    }
}

impl<L> Router<L> {
    /// Add a new service to this router.
    pub fn add_service<S>(mut self, svc: S) -> Self
    where
        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Future: Send + 'static,
    {
        self.routes = self.routes.add_service(svc);
        self
    }

    /// Add a new optional service to this router.
    ///
    /// # Note
    /// Even when the argument given is `None` this will capture *all* requests to this service name.
    /// As a result, one cannot use this to toggle between two identically named implementations.
    #[allow(clippy::type_complexity)]
    pub fn add_optional_service<S>(mut self, svc: Option<S>) -> Self
    where
        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Future: Send + 'static,
    {
        if let Some(svc) = svc {
            self.routes = self.routes.add_service(svc);
        }
        self
    }

    /// Consume this [`Server`] creating a future that will execute the server
    /// on [tokio]'s default executor.
    ///
    /// [`Server`]: struct.Server.html
    /// [tokio]: https://docs.rs/tokio
    pub async fn serve<ResBody>(self, addr: SocketAddr) -> Result<(), super::Error>
    where
        L: Layer<Routes> + Clone,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Future: Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Error:
            Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
            .map_err(super::Error::from_source)?;
        self.server
            .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>(
                self.routes.prepare(),
                incoming,
                None,
            )
            .await
    }

    /// Consume this [`Server`] creating a future that will execute the server
    /// on [tokio]'s default executor, and shut down when the provided signal
    /// is received.
    ///
    /// [`Server`]: struct.Server.html
    /// [tokio]: https://docs.rs/tokio
    pub async fn serve_with_shutdown<F: Future<Output = ()>, ResBody>(
        self,
        addr: SocketAddr,
        signal: F,
    ) -> Result<(), super::Error>
    where
        L: Layer<Routes>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Future: Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Error:
            Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
            .map_err(super::Error::from_source)?;
        self.server
            .serve_with_shutdown(self.routes.prepare(), incoming, Some(signal))
            .await
    }

    /// Consume this [`Server`] creating a future that will execute the server
    /// on the provided incoming stream of `AsyncRead + AsyncWrite`.
    ///
    /// This method discards any provided [`Server`] TCP configuration.
    ///
    /// [`Server`]: struct.Server.html
    pub async fn serve_with_incoming<I, IO, IE, ResBody>(
        self,
        incoming: I,
    ) -> Result<(), super::Error>
    where
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IO::ConnectInfo: Clone + Send + Sync + 'static,
        IE: Into<crate::BoxError>,
        L: Layer<Routes>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Future: Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Error:
            Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        self.server
            .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>(
                self.routes.prepare(),
                incoming,
                None,
            )
            .await
    }

    /// Consume this [`Server`] creating a future that will execute the server
    /// on the provided incoming stream of `AsyncRead + AsyncWrite`. Similar to
    /// `serve_with_shutdown`, this method also takes a signal future used to
    /// gracefully shut down the server.
    ///
    /// This method discards any provided [`Server`] TCP configuration.
    ///
    /// [`Server`]: struct.Server.html
    pub async fn serve_with_incoming_shutdown<I, IO, IE, F, ResBody>(
        self,
        incoming: I,
        signal: F,
    ) -> Result<(), super::Error>
    where
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IO::ConnectInfo: Clone + Send + Sync + 'static,
        IE: Into<crate::BoxError>,
        F: Future<Output = ()>,
        L: Layer<Routes>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Future: Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Error:
            Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        self.server
            .serve_with_shutdown(self.routes.prepare(), incoming, Some(signal))
            .await
    }

    /// Create a tower service out of a router.
    pub fn into_service<ResBody>(self) -> L::Service
    where
        L: Layer<Routes>,
    {
        self.server.service_builder.service(self.routes.prepare())
    }
}

impl<L> fmt::Debug for Server<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Builder").finish()
    }
}

#[derive(Clone)]
struct Svc<S> {
    inner: S,
    trace_interceptor: Option<TraceInterceptor>,
}

impl<S, ResBody> Service<Request<Body>> for Svc<S>
where
    S: Service<Request<Body>, Response = Response<ResBody>>,
    S::Error: Into<crate::BoxError>,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::BoxError>,
{
    type Response = Response<Body>;
    type Error = crate::BoxError;
    type Future = SvcFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        let span = if let Some(trace_interceptor) = &self.trace_interceptor {
            let (parts, body) = req.into_parts();
            let bodyless_request = Request::from_parts(parts, ());

            let span = trace_interceptor(&bodyless_request);

            let (parts, _) = bodyless_request.into_parts();
            req = Request::from_parts(parts, body);

            span
        } else {
            tracing::Span::none()
        };

        SvcFuture {
            inner: self.inner.call(req),
            span,
        }
    }
}

#[pin_project]
struct SvcFuture<F> {
    #[pin]
    inner: F,
    span: tracing::Span,
}

impl<F, E, ResBody> Future for SvcFuture<F>
where
    F: Future<Output = Result<Response<ResBody>, E>>,
    E: Into<crate::BoxError>,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::BoxError>,
{
    type Output = Result<Response<Body>, crate::BoxError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let _guard = this.span.enter();

        let response: Response<ResBody> = ready!(this.inner.poll(cx)).map_err(Into::into)?;
        let response = response.map(|body| Body::new(body.map_err(Into::into)));
        Poll::Ready(Ok(response))
    }
}

impl<S> fmt::Debug for Svc<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Svc").finish()
    }
}

#[derive(Clone)]
struct MakeSvc<S, IO> {
    concurrency_limit: Option<usize>,
    timeout: Option<Duration>,
    inner: S,
    trace_interceptor: Option<TraceInterceptor>,
    _io: PhantomData<fn() -> IO>,
}

impl<S, ResBody, IO> Service<&ServerIo<IO>> for MakeSvc<S, IO>
where
    IO: Connected,
    S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<crate::BoxError> + Send,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::BoxError>,
{
    type Response = BoxService;
    type Error = crate::BoxError;
    type Future = future::Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, io: &ServerIo<IO>) -> Self::Future {
        let conn_info = io.connect_info();

        let svc = self.inner.clone();
        let concurrency_limit = self.concurrency_limit;
        let timeout = self.timeout;
        let trace_interceptor = self.trace_interceptor.clone();

        let svc = ServiceBuilder::new()
            .layer_fn(RecoverError::new)
            .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
            .layer_fn(|s| GrpcTimeout::new(s, timeout))
            .service(svc);

        let svc = ServiceBuilder::new()
            .layer(BoxCloneService::layer())
            .map_request(move |mut request: Request<Body>| {
                // Expose the connection info to handlers via request extensions:
                // plain transport info on the left, TLS info (when enabled) on the right.
                match &conn_info {
                    tower::util::Either::Left(inner) => {
                        request.extensions_mut().insert(inner.clone());
                    }
                    tower::util::Either::Right(inner) => {
                        #[cfg(feature = "tls")]
                        {
                            request.extensions_mut().insert(inner.clone());
                            request.extensions_mut().insert(inner.get_ref().clone());
                        }

                        #[cfg(not(feature = "tls"))]
                        {
                            // just a type check to make sure we didn't forget to
                            // insert this into the extensions
                            let _: &() = inner;
                        }
                    }
                }

                request
            })
            .service(Svc {
                inner: svc,
                trace_interceptor,
            });

        future::ready(Ok(svc))
    }
}

// Borrowed from the `futures-util` crate, since `Fuse` is the only item tonic needs from it.
// LICENSE: MIT or Apache-2.0
// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`.
#[pin_project]
struct Fuse<F> {
    #[pin]
    inner: Option<F>,
}

impl<F> Future for Fuse<F>
where
    F: Future,
{
    type Output = F::Output;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.as_mut().project().inner.as_pin_mut() {
            Some(fut) => fut.poll(cx).map(|output| {
                self.project().inner.set(None);
                output
            }),
            None => Poll::Pending,
        }
    }
}