cronback_lib/
service.rs

1use std::convert::Infallible;
2use std::error::Error;
3use std::net::SocketAddr;
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures::Stream;
9use hyper::{Body, Request, Response};
10use proto::FILE_DESCRIPTOR_SET;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::UnixListener;
13use tokio_stream::wrappers::UnixListenerStream;
14use tonic::body::BoxBody;
15use tonic::transport::server::{Connected, TcpIncoming};
16use tonic::transport::{NamedService, Server};
17use tonic_reflection::server::Builder;
18use tower::Service;
19use tower_http::trace::{MakeSpan, TraceLayer};
20use tracing::{error, error_span, info, Id, Span};
21
22use crate::config::{Config, ConfigLoader};
23use crate::consts::{PARENT_SPAN_HEADER, PROJECT_ID_HEADER, REQUEST_ID_HEADER};
24use crate::rpc_middleware::CronbackRpcMiddleware;
25use crate::shutdown::Shutdown;
26
27#[derive(Clone)]
28pub struct ServiceContext {
29    name: String,
30    config_loader: Arc<ConfigLoader>,
31    shutdown: Shutdown,
32}
33
34impl ServiceContext {
35    pub fn new(
36        name: String,
37        config_loader: Arc<ConfigLoader>,
38        shutdown: Shutdown,
39    ) -> Self {
40        Self {
41            name,
42            config_loader,
43            shutdown,
44        }
45    }
46
47    pub fn service_name(&self) -> &str {
48        &self.name
49    }
50
51    pub fn get_config(&self) -> Config {
52        self.config_loader.load().unwrap()
53    }
54
55    pub fn config_loader(&self) -> Arc<ConfigLoader> {
56        self.config_loader.clone()
57    }
58
59    pub fn load_config(&self) -> Config {
60        self.config_loader.load().unwrap()
61    }
62
63    /// Awaits the shutdown signal
64    pub async fn recv_shutdown_signal(&mut self) {
65        self.shutdown.recv().await
66    }
67
68    /// Causes all listeners to start the shutdown sequence.
69    pub fn broadcast_shutdown(&mut self) {
70        self.shutdown.broadcast_shutdown()
71    }
72}
73
74#[derive(Clone, Debug)]
75struct GrpcMakeSpan {
76    service_name: String,
77}
78impl GrpcMakeSpan {
79    fn new(service_name: String) -> Self {
80        Self { service_name }
81    }
82}
83
84impl<B> MakeSpan<B> for GrpcMakeSpan {
85    fn make_span(&mut self, request: &Request<B>) -> Span {
86        let request_id = request
87            .headers()
88            .get(REQUEST_ID_HEADER)
89            .map(|v| v.to_str().unwrap().to_owned());
90
91        let parent_span = request
92            .headers()
93            .get(PARENT_SPAN_HEADER)
94            .map(|v| v.to_str().unwrap().to_owned());
95
96        let project_id = request
97            .headers()
98            .get(PROJECT_ID_HEADER)
99            .map(|v| v.to_str().unwrap().to_owned());
100
101        let span = error_span!(
102            "grpc_request",
103             service = %self.service_name,
104             request_id = %request_id.unwrap_or_default(),
105             project_id = %project_id.unwrap_or_default(),
106             method = %request.method(),
107             uri = %request.uri(),
108             version = ?request.version(),
109        );
110
111        if let Some(parent_span) = parent_span {
112            let id = Id::from_u64(parent_span.parse().unwrap());
113            span.follows_from(id);
114        }
115        span
116    }
117}
118
119#[tracing::instrument(skip_all, fields(service = context.service_name()))]
120pub async fn grpc_serve_tcp<S>(
121    context: &mut ServiceContext,
122    addr: SocketAddr,
123    svc: S,
124    timeout: u64,
125) where
126    S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
127        + NamedService
128        + Clone
129        + Send
130        + 'static,
131    S::Future: Send + 'static,
132{
133    info!("Starting '{}' on {:?}", context.service_name(), addr);
134    match TcpIncoming::new(addr, true, None) {
135        | Ok(incoming) => {
136            grpc_serve_incoming(context, svc, incoming, timeout).await
137        }
138        | Err(e) => {
139            error!(
140                "RPC service '{}' couldn't bind on address '{addr}', system \
141                 will shutdown: {e}",
142                context.service_name()
143            );
144            context.broadcast_shutdown();
145        }
146    };
147}
148
149#[tracing::instrument(skip_all, fields(service = context.service_name()))]
150pub async fn grpc_serve_unix<S, K>(
151    context: &mut ServiceContext,
152    socket: K,
153    svc: S,
154    timeout: u64,
155) where
156    S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
157        + NamedService
158        + Clone
159        + Send
160        + 'static,
161    S::Future: Send + 'static,
162    K: AsRef<Path>,
163{
164    info!(
165        "Starting '{}' on {:?}",
166        context.service_name(),
167        socket.as_ref()
168    );
169    let uds = UnixListener::bind(socket).unwrap();
170    let stream = UnixListenerStream::new(uds);
171    grpc_serve_incoming(context, svc, stream, timeout).await
172}
173
174#[tracing::instrument(skip_all, fields(service = context.service_name()))]
175async fn grpc_serve_incoming<S, K, IO, IE>(
176    context: &mut ServiceContext,
177    svc: S,
178    incoming: K,
179    timeout: u64,
180) where
181    S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
182        + NamedService
183        + Clone
184        + Send
185        + 'static,
186    S::Future: Send + 'static,
187    K: Stream<Item = Result<IO, IE>>,
188    IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
189    IO::ConnectInfo: Clone + Send + Sync + 'static,
190    IE: Into<Box<dyn Error + Send + Sync>>,
191{
192    let svc_name = context.service_name().to_owned();
193    // The stack of middleware that our service will be wrapped in
194    let cronback_middleware = tower::ServiceBuilder::new()
195        // Apply our own middleware
196        .layer(
197            TraceLayer::new_for_grpc()
198                .make_span_with(GrpcMakeSpan::new(svc_name)),
199        )
200        .layer(CronbackRpcMiddleware::new(context.service_name()))
201        .into_inner();
202
203    let reflection_service = Builder::configure()
204        .register_encoded_file_descriptor_set(FILE_DESCRIPTOR_SET)
205        .build()
206        .unwrap();
207
208    // grpc Server
209    if let Err(e) = Server::builder()
210        .timeout(Duration::from_secs(timeout))
211        .layer(cronback_middleware)
212        .add_service(reflection_service)
213        .add_service(svc)
214        .serve_with_incoming_shutdown(incoming, context.recv_shutdown_signal())
215        .await
216    {
217        error!(
218            "RPC service '{}' failed to start and will trigger system \
219             shutdown: {e}",
220            context.service_name()
221        );
222        context.broadcast_shutdown()
223    } else {
224        info!("Service '{}' terminated", context.service_name());
225    }
226}