1use std::convert::Infallible;
2use std::error::Error;
3use std::net::SocketAddr;
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures::Stream;
9use hyper::{Body, Request, Response};
10use proto::FILE_DESCRIPTOR_SET;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::UnixListener;
13use tokio_stream::wrappers::UnixListenerStream;
14use tonic::body::BoxBody;
15use tonic::transport::server::{Connected, TcpIncoming};
16use tonic::transport::{NamedService, Server};
17use tonic_reflection::server::Builder;
18use tower::Service;
19use tower_http::trace::{MakeSpan, TraceLayer};
20use tracing::{error, error_span, info, Id, Span};
21
22use crate::config::{Config, ConfigLoader};
23use crate::consts::{PARENT_SPAN_HEADER, PROJECT_ID_HEADER, REQUEST_ID_HEADER};
24use crate::rpc_middleware::CronbackRpcMiddleware;
25use crate::shutdown::Shutdown;
26
27#[derive(Clone)]
28pub struct ServiceContext {
29 name: String,
30 config_loader: Arc<ConfigLoader>,
31 shutdown: Shutdown,
32}
33
34impl ServiceContext {
35 pub fn new(
36 name: String,
37 config_loader: Arc<ConfigLoader>,
38 shutdown: Shutdown,
39 ) -> Self {
40 Self {
41 name,
42 config_loader,
43 shutdown,
44 }
45 }
46
47 pub fn service_name(&self) -> &str {
48 &self.name
49 }
50
51 pub fn get_config(&self) -> Config {
52 self.config_loader.load().unwrap()
53 }
54
55 pub fn config_loader(&self) -> Arc<ConfigLoader> {
56 self.config_loader.clone()
57 }
58
59 pub fn load_config(&self) -> Config {
60 self.config_loader.load().unwrap()
61 }
62
63 pub async fn recv_shutdown_signal(&mut self) {
65 self.shutdown.recv().await
66 }
67
68 pub fn broadcast_shutdown(&mut self) {
70 self.shutdown.broadcast_shutdown()
71 }
72}
73
/// Span factory for the gRPC tracing layer: every span it creates is tagged
/// with the owning service's name.
#[derive(Debug, Clone)]
struct GrpcMakeSpan {
    /// Value recorded in the `service` field of each span.
    service_name: String,
}

impl GrpcMakeSpan {
    /// Builds a span factory for the service called `service_name`.
    fn new(service_name: String) -> Self {
        GrpcMakeSpan { service_name }
    }
}
83
84impl<B> MakeSpan<B> for GrpcMakeSpan {
85 fn make_span(&mut self, request: &Request<B>) -> Span {
86 let request_id = request
87 .headers()
88 .get(REQUEST_ID_HEADER)
89 .map(|v| v.to_str().unwrap().to_owned());
90
91 let parent_span = request
92 .headers()
93 .get(PARENT_SPAN_HEADER)
94 .map(|v| v.to_str().unwrap().to_owned());
95
96 let project_id = request
97 .headers()
98 .get(PROJECT_ID_HEADER)
99 .map(|v| v.to_str().unwrap().to_owned());
100
101 let span = error_span!(
102 "grpc_request",
103 service = %self.service_name,
104 request_id = %request_id.unwrap_or_default(),
105 project_id = %project_id.unwrap_or_default(),
106 method = %request.method(),
107 uri = %request.uri(),
108 version = ?request.version(),
109 );
110
111 if let Some(parent_span) = parent_span {
112 let id = Id::from_u64(parent_span.parse().unwrap());
113 span.follows_from(id);
114 }
115 span
116 }
117}
118
119#[tracing::instrument(skip_all, fields(service = context.service_name()))]
120pub async fn grpc_serve_tcp<S>(
121 context: &mut ServiceContext,
122 addr: SocketAddr,
123 svc: S,
124 timeout: u64,
125) where
126 S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
127 + NamedService
128 + Clone
129 + Send
130 + 'static,
131 S::Future: Send + 'static,
132{
133 info!("Starting '{}' on {:?}", context.service_name(), addr);
134 match TcpIncoming::new(addr, true, None) {
135 | Ok(incoming) => {
136 grpc_serve_incoming(context, svc, incoming, timeout).await
137 }
138 | Err(e) => {
139 error!(
140 "RPC service '{}' couldn't bind on address '{addr}', system \
141 will shutdown: {e}",
142 context.service_name()
143 );
144 context.broadcast_shutdown();
145 }
146 };
147}
148
149#[tracing::instrument(skip_all, fields(service = context.service_name()))]
150pub async fn grpc_serve_unix<S, K>(
151 context: &mut ServiceContext,
152 socket: K,
153 svc: S,
154 timeout: u64,
155) where
156 S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
157 + NamedService
158 + Clone
159 + Send
160 + 'static,
161 S::Future: Send + 'static,
162 K: AsRef<Path>,
163{
164 info!(
165 "Starting '{}' on {:?}",
166 context.service_name(),
167 socket.as_ref()
168 );
169 let uds = UnixListener::bind(socket).unwrap();
170 let stream = UnixListenerStream::new(uds);
171 grpc_serve_incoming(context, svc, stream, timeout).await
172}
173
174#[tracing::instrument(skip_all, fields(service = context.service_name()))]
175async fn grpc_serve_incoming<S, K, IO, IE>(
176 context: &mut ServiceContext,
177 svc: S,
178 incoming: K,
179 timeout: u64,
180) where
181 S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
182 + NamedService
183 + Clone
184 + Send
185 + 'static,
186 S::Future: Send + 'static,
187 K: Stream<Item = Result<IO, IE>>,
188 IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
189 IO::ConnectInfo: Clone + Send + Sync + 'static,
190 IE: Into<Box<dyn Error + Send + Sync>>,
191{
192 let svc_name = context.service_name().to_owned();
193 let cronback_middleware = tower::ServiceBuilder::new()
195 .layer(
197 TraceLayer::new_for_grpc()
198 .make_span_with(GrpcMakeSpan::new(svc_name)),
199 )
200 .layer(CronbackRpcMiddleware::new(context.service_name()))
201 .into_inner();
202
203 let reflection_service = Builder::configure()
204 .register_encoded_file_descriptor_set(FILE_DESCRIPTOR_SET)
205 .build()
206 .unwrap();
207
208 if let Err(e) = Server::builder()
210 .timeout(Duration::from_secs(timeout))
211 .layer(cronback_middleware)
212 .add_service(reflection_service)
213 .add_service(svc)
214 .serve_with_incoming_shutdown(incoming, context.recv_shutdown_signal())
215 .await
216 {
217 error!(
218 "RPC service '{}' failed to start and will trigger system \
219 shutdown: {e}",
220 context.service_name()
221 );
222 context.broadcast_shutdown()
223 } else {
224 info!("Service '{}' terminated", context.service_name());
225 }
226}