mysten_network/
server.rs

1// Copyright (c) 2022, Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3use crate::metrics::{
4    DefaultMetricsCallbackProvider, MetricsCallbackProvider, MetricsHandler,
5    GRPC_ENDPOINT_PATH_HEADER,
6};
7use crate::{
8    config::Config,
9    multiaddr::{parse_dns, parse_ip4, parse_ip6},
10};
11use eyre::{eyre, Result};
12use futures::FutureExt;
13use multiaddr::{Multiaddr, Protocol};
14use std::task::{Context, Poll};
15use std::{convert::Infallible, net::SocketAddr};
16use tokio::net::{TcpListener, ToSocketAddrs};
17use tokio_stream::wrappers::TcpListenerStream;
18use tonic::codegen::http::HeaderValue;
19use tonic::{
20    body::BoxBody,
21    codegen::{
22        http::{Request, Response},
23        BoxFuture,
24    },
25    transport::{server::Router, Body, NamedService},
26};
27use tower::{
28    layer::util::{Identity, Stack},
29    limit::GlobalConcurrencyLimitLayer,
30    load_shed::LoadShedLayer,
31    util::Either,
32    Layer, Service, ServiceBuilder,
33};
34use tower_http::classify::{GrpcErrorsAsFailures, SharedClassifier};
35use tower_http::propagate_header::PropagateHeaderLayer;
36use tower_http::set_header::SetRequestHeaderLayer;
37use tower_http::trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer};
38
39pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
40    router: Router<WrapperService<M>>,
41    health_reporter: tonic_health::server::HealthReporter,
42}
43
44type AddPathToHeaderFunction = fn(&Request<Body>) -> Option<HeaderValue>;
45
46type WrapperService<M> = Stack<
47    Stack<
48        PropagateHeaderLayer,
49        Stack<
50            TraceLayer<
51                SharedClassifier<GrpcErrorsAsFailures>,
52                DefaultMakeSpan,
53                MetricsHandler<M>,
54                MetricsHandler<M>,
55                DefaultOnBodyChunk,
56                DefaultOnEos,
57                MetricsHandler<M>,
58            >,
59            Stack<
60                SetRequestHeaderLayer<AddPathToHeaderFunction>,
61                Stack<
62                    RequestLifetimeLayer<M>,
63                    Stack<
64                        Either<LoadShedLayer, Identity>,
65                        Stack<Either<GlobalConcurrencyLimitLayer, Identity>, Identity>,
66                    >,
67                >,
68            >,
69        >,
70    >,
71    Identity,
72>;
73
74impl<M: MetricsCallbackProvider> ServerBuilder<M> {
75    pub fn from_config(config: &Config, metrics_provider: M) -> Self {
76        let mut builder = tonic::transport::server::Server::builder();
77
78        if let Some(limit) = config.concurrency_limit_per_connection {
79            builder = builder.concurrency_limit_per_connection(limit);
80        }
81
82        if let Some(timeout) = config.request_timeout {
83            builder = builder.timeout(timeout);
84        }
85
86        if let Some(tcp_nodelay) = config.tcp_nodelay {
87            builder = builder.tcp_nodelay(tcp_nodelay);
88        }
89
90        let load_shed = config
91            .load_shed
92            .unwrap_or_default()
93            .then_some(tower::load_shed::LoadShedLayer::new());
94
95        let metrics = MetricsHandler::new(metrics_provider.clone());
96
97        let request_metrics = TraceLayer::new_for_grpc()
98            .on_request(metrics.clone())
99            .on_response(metrics.clone())
100            .on_failure(metrics);
101
102        let global_concurrency_limit = config
103            .global_concurrency_limit
104            .map(tower::limit::GlobalConcurrencyLimitLayer::new);
105
106        fn add_path_to_request_header(request: &Request<Body>) -> Option<HeaderValue> {
107            let path = request.uri().path();
108            Some(HeaderValue::from_str(path).unwrap())
109        }
110
111        let layer = ServiceBuilder::new()
112            .option_layer(global_concurrency_limit)
113            .option_layer(load_shed)
114            .layer(RequestLifetimeLayer { metrics_provider })
115            .layer(SetRequestHeaderLayer::overriding(
116                GRPC_ENDPOINT_PATH_HEADER.clone(),
117                add_path_to_request_header as AddPathToHeaderFunction,
118            ))
119            .layer(request_metrics)
120            .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
121            .into_inner();
122
123        let (health_reporter, health_service) = tonic_health::server::health_reporter();
124        let router = builder
125            .initial_stream_window_size(config.http2_initial_stream_window_size)
126            .initial_connection_window_size(config.http2_initial_connection_window_size)
127            .http2_keepalive_interval(config.http2_keepalive_interval)
128            .http2_keepalive_timeout(config.http2_keepalive_timeout)
129            .max_concurrent_streams(config.http2_max_concurrent_streams)
130            .tcp_keepalive(config.tcp_keepalive)
131            .layer(layer)
132            .add_service(health_service);
133
134        Self {
135            router,
136            health_reporter,
137        }
138    }
139
140    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
141        self.health_reporter.clone()
142    }
143
144    /// Add a new service to this Server.
145    pub fn add_service<S>(mut self, svc: S) -> Self
146    where
147        S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
148            + NamedService
149            + Clone
150            + Send
151            + 'static,
152        S::Future: Send + 'static,
153    {
154        self.router = self.router.add_service(svc);
155        self
156    }
157
158    pub async fn bind(self, addr: &Multiaddr) -> Result<Server> {
159        let mut iter = addr.iter();
160
161        let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel();
162        let rx_cancellation = rx_cancellation.map(|_| ());
163        let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) =
164            match iter.next().ok_or_else(|| eyre!("malformed addr"))? {
165                Protocol::Dns(_) => {
166                    let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?;
167                    let (local_addr, incoming) =
168                        tcp_listener_and_update_multiaddr(addr, (dns_name.as_ref(), tcp_port))
169                            .await?;
170                    let server = Box::pin(
171                        self.router
172                            .serve_with_incoming_shutdown(incoming, rx_cancellation),
173                    );
174                    (local_addr, server)
175                }
176                Protocol::Ip4(_) => {
177                    let (socket_addr, _http_or_https) = parse_ip4(addr)?;
178                    let (local_addr, incoming) =
179                        tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
180                    let server = Box::pin(
181                        self.router
182                            .serve_with_incoming_shutdown(incoming, rx_cancellation),
183                    );
184                    (local_addr, server)
185                }
186                Protocol::Ip6(_) => {
187                    let (socket_addr, _http_or_https) = parse_ip6(addr)?;
188                    let (local_addr, incoming) =
189                        tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
190                    let server = Box::pin(
191                        self.router
192                            .serve_with_incoming_shutdown(incoming, rx_cancellation),
193                    );
194                    (local_addr, server)
195                }
196                // Protocol::Memory(_) => todo!(),
197                #[cfg(unix)]
198                Protocol::Unix(_) => {
199                    let (path, _http_or_https) = crate::multiaddr::parse_unix(addr)?;
200                    let uds = tokio::net::UnixListener::bind(path.as_ref())?;
201                    let uds_stream = tokio_stream::wrappers::UnixListenerStream::new(uds);
202                    let local_addr = addr.to_owned();
203                    let server = Box::pin(
204                        self.router
205                            .serve_with_incoming_shutdown(uds_stream, rx_cancellation),
206                    );
207                    (local_addr, server)
208                }
209                unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
210            };
211
212        Ok(Server {
213            server,
214            cancel_handle: Some(tx_cancellation),
215            local_addr,
216            health_reporter: self.health_reporter,
217        })
218    }
219}
220
221async fn tcp_listener_and_update_multiaddr<T: ToSocketAddrs>(
222    address: &Multiaddr,
223    socket_addr: T,
224) -> Result<(Multiaddr, TcpListenerStream)> {
225    let (local_addr, incoming) = tcp_listener(socket_addr).await?;
226    let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port());
227    Ok((local_addr, incoming))
228}
229
230async fn tcp_listener<T: ToSocketAddrs>(address: T) -> Result<(SocketAddr, TcpListenerStream)> {
231    let listener = TcpListener::bind(address).await?;
232    let local_addr = listener.local_addr()?;
233    let incoming = TcpListenerStream::new(listener);
234    Ok((local_addr, incoming))
235}
236
237pub struct Server {
238    server: BoxFuture<(), tonic::transport::Error>,
239    cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
240    local_addr: Multiaddr,
241    health_reporter: tonic_health::server::HealthReporter,
242}
243
244impl Server {
245    pub async fn serve(self) -> Result<(), tonic::transport::Error> {
246        self.server.await
247    }
248
249    pub fn local_addr(&self) -> &Multiaddr {
250        &self.local_addr
251    }
252
253    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
254        self.health_reporter.clone()
255    }
256
257    pub fn take_cancel_handle(&mut self) -> Option<tokio::sync::oneshot::Sender<()>> {
258        self.cancel_handle.take()
259    }
260}
261
262fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
263    addr.replace(1, |protocol| {
264        if let Protocol::Tcp(_) = protocol {
265            Some(Protocol::Tcp(port))
266        } else {
267            panic!("expected tcp protocol at index 1");
268        }
269    })
270    .expect("tcp protocol at index 1")
271}
272
#[cfg(test)]
mod test {
    use crate::config::Config;
    use crate::metrics::MetricsCallbackProvider;
    use multiaddr::multiaddr;
    use multiaddr::Multiaddr;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tonic::Code;
    use tonic_health::proto::health_client::HealthClient;
    use tonic_health::proto::HealthCheckRequest;

    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        // You can construct a multiaddr by hand (ie binary format) just fine
        let path = "/tmp/foo";
        let addr = multiaddr!(Unix(path), Http);

        // But it doesn't round-trip in the human readable format
        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    /// A successful health check must reach `on_response` with HTTP 200 and gRPC `Ok`.
    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// Set to `true` once the `on_response` callback has run.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(*metrics.metrics_called.lock().unwrap());
    }

    /// A health check for an unknown service must reach `on_response` with HTTP 200 and
    /// gRPC `NotFound`.
    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// Set to `true` once the `on_response` callback has run.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                // According to https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
                // code 5 is not_found, which is what we expect to get in this case
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        // Call the healthcheck for a service that doesn't exist
        // that should give us back an error with code 5 (not_found)
        // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
        let _ = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await;

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(*metrics.metrics_called.lock().unwrap());
    }

    /// Binds a server on `address`, performs one health check against it and shuts it down.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let mut server = config.server_builder().bind(&address).await.unwrap();
        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn unix() {
        // Note that this only works when constructing a multiaddr by hand and not via the
        // human-readable format
        let path = "unix-domain-socket";
        // A socket file left behind by an earlier aborted run would make the bind fail, so
        // clear it first; a missing file is the normal case and the error is ignored.
        let _ = std::fs::remove_file(path);
        let address = multiaddr!(Unix(path), Http);
        test_multiaddr(address).await;
        std::fs::remove_file(path).unwrap();
    }

    #[should_panic]
    #[tokio::test]
    async fn missing_http_protocol() {
        let address: Multiaddr = "/dns/localhost/tcp/0".parse().unwrap();
        test_multiaddr(address).await;
    }
}
479
480#[derive(Clone)]
481struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
482    metrics_provider: M,
483}
484
485impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
486    type Service = RequestLifetime<M, S>;
487
488    fn layer(&self, inner: S) -> Self::Service {
489        RequestLifetime {
490            inner,
491            metrics_provider: self.metrics_provider.clone(),
492            path: None,
493        }
494    }
495}
496
497#[derive(Clone)]
498struct RequestLifetime<M: MetricsCallbackProvider, S> {
499    inner: S,
500    metrics_provider: M,
501    path: Option<String>,
502}
503
504impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
505    for RequestLifetime<M, S>
506where
507    S: Service<Request<RequestBody>>,
508{
509    type Response = S::Response;
510    type Error = S::Error;
511    type Future = S::Future;
512
513    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
514        self.inner.poll_ready(cx)
515    }
516
517    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
518        if self.path.is_none() {
519            let path = request.uri().path().to_string();
520            self.metrics_provider.on_start(&path);
521            self.path = Some(path);
522        }
523        self.inner.call(request)
524    }
525}
526
527impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
528    fn drop(&mut self) {
529        if let Some(path) = &self.path {
530            self.metrics_provider.on_drop(path)
531        }
532    }
533}