1use crate::metrics::{
4 DefaultMetricsCallbackProvider, MetricsCallbackProvider, MetricsHandler,
5 GRPC_ENDPOINT_PATH_HEADER,
6};
7use crate::{
8 config::Config,
9 multiaddr::{parse_dns, parse_ip4, parse_ip6},
10};
11use eyre::{eyre, Result};
12use futures::FutureExt;
13use multiaddr::{Multiaddr, Protocol};
14use std::task::{Context, Poll};
15use std::{convert::Infallible, net::SocketAddr};
16use tokio::net::{TcpListener, ToSocketAddrs};
17use tokio_stream::wrappers::TcpListenerStream;
18use tonic::codegen::http::HeaderValue;
19use tonic::{
20 body::BoxBody,
21 codegen::{
22 http::{Request, Response},
23 BoxFuture,
24 },
25 transport::{server::Router, Body, NamedService},
26};
27use tower::{
28 layer::util::{Identity, Stack},
29 limit::GlobalConcurrencyLimitLayer,
30 load_shed::LoadShedLayer,
31 util::Either,
32 Layer, Service, ServiceBuilder,
33};
34use tower_http::classify::{GrpcErrorsAsFailures, SharedClassifier};
35use tower_http::propagate_header::PropagateHeaderLayer;
36use tower_http::set_header::SetRequestHeaderLayer;
37use tower_http::trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer};
38
39pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
40 router: Router<WrapperService<M>>,
41 health_reporter: tonic_health::server::HealthReporter,
42}
43
/// Signature of the function that `SetRequestHeaderLayer` uses to compute the
/// endpoint-path header value from an incoming request.
///
/// A plain `fn` pointer (rather than a closure) keeps the layer's type nameable
/// inside [`WrapperService`].
type AddPathToHeaderFunction = fn(&Request<Body>) -> Option<HeaderValue>;

/// Concrete type of the middleware stack built in `ServerBuilder::from_config`.
///
/// `ServiceBuilder` records layers as nested `Stack<Inner, Outer>` values, so this
/// type reads inside-out. The first type argument of the outermost `Stack` here is
/// `PropagateHeaderLayer`, the layer closest to the service. The deepest `Stack`
/// holds the optional global concurrency limit, the outermost layer.
///
/// The trailing `Stack<_, Identity>` wrapper comes from tonic's `Server::layer`. That
/// method stacks the new layer on top of the builder's existing (default `Identity`)
/// layer.
///
/// Any change to the `ServiceBuilder` chain in `from_config` must be mirrored here.
type WrapperService<M> = Stack<
    Stack<
        // Copies the endpoint-path header from the request onto the response.
        PropagateHeaderLayer,
        Stack<
            // gRPC-aware tracing whose request/response/failure hooks report to
            // the metrics provider.
            TraceLayer<
                SharedClassifier<GrpcErrorsAsFailures>,
                DefaultMakeSpan,
                MetricsHandler<M>,
                MetricsHandler<M>,
                DefaultOnBodyChunk,
                DefaultOnEos,
                MetricsHandler<M>,
            >,
            Stack<
                // Writes the request's URI path into the endpoint-path header.
                SetRequestHeaderLayer<AddPathToHeaderFunction>,
                Stack<
                    RequestLifetimeLayer<M>,
                    Stack<
                        // `Either<_, Identity>` is what `option_layer` produces:
                        // the real layer when configured, a no-op otherwise.
                        Either<LoadShedLayer, Identity>,
                        Stack<Either<GlobalConcurrencyLimitLayer, Identity>, Identity>,
                    >,
                >,
            >,
        >,
    >,
    Identity,
>;
73
74impl<M: MetricsCallbackProvider> ServerBuilder<M> {
75 pub fn from_config(config: &Config, metrics_provider: M) -> Self {
76 let mut builder = tonic::transport::server::Server::builder();
77
78 if let Some(limit) = config.concurrency_limit_per_connection {
79 builder = builder.concurrency_limit_per_connection(limit);
80 }
81
82 if let Some(timeout) = config.request_timeout {
83 builder = builder.timeout(timeout);
84 }
85
86 if let Some(tcp_nodelay) = config.tcp_nodelay {
87 builder = builder.tcp_nodelay(tcp_nodelay);
88 }
89
90 let load_shed = config
91 .load_shed
92 .unwrap_or_default()
93 .then_some(tower::load_shed::LoadShedLayer::new());
94
95 let metrics = MetricsHandler::new(metrics_provider.clone());
96
97 let request_metrics = TraceLayer::new_for_grpc()
98 .on_request(metrics.clone())
99 .on_response(metrics.clone())
100 .on_failure(metrics);
101
102 let global_concurrency_limit = config
103 .global_concurrency_limit
104 .map(tower::limit::GlobalConcurrencyLimitLayer::new);
105
106 fn add_path_to_request_header(request: &Request<Body>) -> Option<HeaderValue> {
107 let path = request.uri().path();
108 Some(HeaderValue::from_str(path).unwrap())
109 }
110
111 let layer = ServiceBuilder::new()
112 .option_layer(global_concurrency_limit)
113 .option_layer(load_shed)
114 .layer(RequestLifetimeLayer { metrics_provider })
115 .layer(SetRequestHeaderLayer::overriding(
116 GRPC_ENDPOINT_PATH_HEADER.clone(),
117 add_path_to_request_header as AddPathToHeaderFunction,
118 ))
119 .layer(request_metrics)
120 .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
121 .into_inner();
122
123 let (health_reporter, health_service) = tonic_health::server::health_reporter();
124 let router = builder
125 .initial_stream_window_size(config.http2_initial_stream_window_size)
126 .initial_connection_window_size(config.http2_initial_connection_window_size)
127 .http2_keepalive_interval(config.http2_keepalive_interval)
128 .http2_keepalive_timeout(config.http2_keepalive_timeout)
129 .max_concurrent_streams(config.http2_max_concurrent_streams)
130 .tcp_keepalive(config.tcp_keepalive)
131 .layer(layer)
132 .add_service(health_service);
133
134 Self {
135 router,
136 health_reporter,
137 }
138 }
139
140 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
141 self.health_reporter.clone()
142 }
143
144 pub fn add_service<S>(mut self, svc: S) -> Self
146 where
147 S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
148 + NamedService
149 + Clone
150 + Send
151 + 'static,
152 S::Future: Send + 'static,
153 {
154 self.router = self.router.add_service(svc);
155 self
156 }
157
158 pub async fn bind(self, addr: &Multiaddr) -> Result<Server> {
159 let mut iter = addr.iter();
160
161 let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel();
162 let rx_cancellation = rx_cancellation.map(|_| ());
163 let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) =
164 match iter.next().ok_or_else(|| eyre!("malformed addr"))? {
165 Protocol::Dns(_) => {
166 let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?;
167 let (local_addr, incoming) =
168 tcp_listener_and_update_multiaddr(addr, (dns_name.as_ref(), tcp_port))
169 .await?;
170 let server = Box::pin(
171 self.router
172 .serve_with_incoming_shutdown(incoming, rx_cancellation),
173 );
174 (local_addr, server)
175 }
176 Protocol::Ip4(_) => {
177 let (socket_addr, _http_or_https) = parse_ip4(addr)?;
178 let (local_addr, incoming) =
179 tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
180 let server = Box::pin(
181 self.router
182 .serve_with_incoming_shutdown(incoming, rx_cancellation),
183 );
184 (local_addr, server)
185 }
186 Protocol::Ip6(_) => {
187 let (socket_addr, _http_or_https) = parse_ip6(addr)?;
188 let (local_addr, incoming) =
189 tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
190 let server = Box::pin(
191 self.router
192 .serve_with_incoming_shutdown(incoming, rx_cancellation),
193 );
194 (local_addr, server)
195 }
196 #[cfg(unix)]
198 Protocol::Unix(_) => {
199 let (path, _http_or_https) = crate::multiaddr::parse_unix(addr)?;
200 let uds = tokio::net::UnixListener::bind(path.as_ref())?;
201 let uds_stream = tokio_stream::wrappers::UnixListenerStream::new(uds);
202 let local_addr = addr.to_owned();
203 let server = Box::pin(
204 self.router
205 .serve_with_incoming_shutdown(uds_stream, rx_cancellation),
206 );
207 (local_addr, server)
208 }
209 unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
210 };
211
212 Ok(Server {
213 server,
214 cancel_handle: Some(tx_cancellation),
215 local_addr,
216 health_reporter: self.health_reporter,
217 })
218 }
219}
220
221async fn tcp_listener_and_update_multiaddr<T: ToSocketAddrs>(
222 address: &Multiaddr,
223 socket_addr: T,
224) -> Result<(Multiaddr, TcpListenerStream)> {
225 let (local_addr, incoming) = tcp_listener(socket_addr).await?;
226 let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port());
227 Ok((local_addr, incoming))
228}
229
230async fn tcp_listener<T: ToSocketAddrs>(address: T) -> Result<(SocketAddr, TcpListenerStream)> {
231 let listener = TcpListener::bind(address).await?;
232 let local_addr = listener.local_addr()?;
233 let incoming = TcpListenerStream::new(listener);
234 Ok((local_addr, incoming))
235}
236
237pub struct Server {
238 server: BoxFuture<(), tonic::transport::Error>,
239 cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
240 local_addr: Multiaddr,
241 health_reporter: tonic_health::server::HealthReporter,
242}
243
244impl Server {
245 pub async fn serve(self) -> Result<(), tonic::transport::Error> {
246 self.server.await
247 }
248
249 pub fn local_addr(&self) -> &Multiaddr {
250 &self.local_addr
251 }
252
253 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
254 self.health_reporter.clone()
255 }
256
257 pub fn take_cancel_handle(&mut self) -> Option<tokio::sync::oneshot::Sender<()>> {
258 self.cancel_handle.take()
259 }
260}
261
262fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
263 addr.replace(1, |protocol| {
264 if let Protocol::Tcp(_) = protocol {
265 Some(Protocol::Tcp(port))
266 } else {
267 panic!("expected tcp protocol at index 1");
268 }
269 })
270 .expect("tcp protocol at index 1")
271}
272
#[cfg(test)]
mod test {
    use crate::config::Config;
    use crate::metrics::MetricsCallbackProvider;
    use multiaddr::multiaddr;
    use multiaddr::Multiaddr;
    use std::ops::Deref;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tonic::Code;
    use tonic_health::proto::health_client::HealthClient;
    use tonic_health::proto::HealthCheckRequest;

    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        // A `/unix/.../http` address can be built with the `multiaddr!` macro...
        let path = "/tmp/foo";
        let addr = multiaddr!(Unix(path), Http);

        // ...but its string form does not parse back into a `Multiaddr`.
        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// Set by `on_response` so the test can check the callback ran.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// Set by `on_response` so the test can check the callback ran.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        // The health service does not know this name, so the client must observe the
        // same `NotFound` code that the metrics callback asserts on server-side.
        let status = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await
            .expect_err("health check for an unregistered service should fail");
        assert_eq!(status.code(), Code::NotFound);

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// Binds a server on `address`, performs one health check against it, then shuts
    /// it down cleanly.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let mut server = config.server_builder().bind(&address).await.unwrap();
        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn unix() {
        let path = "unix-domain-socket";
        // Binding fails if the socket file already exists, e.g. when an earlier run
        // aborted before its cleanup. Remove any leftover; absence is fine.
        let _ = std::fs::remove_file(path);
        let address = multiaddr!(Unix(path), Http);
        test_multiaddr(address).await;
        std::fs::remove_file(path).unwrap();
    }

    #[should_panic]
    #[tokio::test]
    async fn missing_http_protocol() {
        let address: Multiaddr = "/dns/localhost/tcp/0".parse().unwrap();
        test_multiaddr(address).await;
    }
}
479
480#[derive(Clone)]
481struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
482 metrics_provider: M,
483}
484
485impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
486 type Service = RequestLifetime<M, S>;
487
488 fn layer(&self, inner: S) -> Self::Service {
489 RequestLifetime {
490 inner,
491 metrics_provider: self.metrics_provider.clone(),
492 path: None,
493 }
494 }
495}
496
497#[derive(Clone)]
498struct RequestLifetime<M: MetricsCallbackProvider, S> {
499 inner: S,
500 metrics_provider: M,
501 path: Option<String>,
502}
503
504impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
505 for RequestLifetime<M, S>
506where
507 S: Service<Request<RequestBody>>,
508{
509 type Response = S::Response;
510 type Error = S::Error;
511 type Future = S::Future;
512
513 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
514 self.inner.poll_ready(cx)
515 }
516
517 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
518 if self.path.is_none() {
519 let path = request.uri().path().to_string();
520 self.metrics_provider.on_start(&path);
521 self.path = Some(path);
522 }
523 self.inner.call(request)
524 }
525}
526
527impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
528 fn drop(&mut self) {
529 if let Some(path) = &self.path {
530 self.metrics_provider.on_drop(path)
531 }
532 }
533}