cronback_lib/
rpc_middleware.rs

1use std::task::{Context, Poll};
2use std::time::Instant;
3
4use hyper::Body;
5use metrics::{histogram, increment_counter};
6use tonic::body::BoxBody;
7use tower::{Layer, Service};
8
9use crate::consts::{PROJECT_ID_HEADER, REQUEST_ID_HEADER};
10use crate::model::{ModelId, ValidShardedId};
11use crate::types::{ProjectId, RequestId};
12
/// Layer configuration for the cronback RPC middleware.
///
/// Holds the service name that is attached to emitted metrics.
#[derive(Debug, Clone, Default)]
pub struct CronbackRpcMiddleware {
    /// Value of the "service" label in emitted metrics.
    service_name: String,
}

impl CronbackRpcMiddleware {
    /// Creates a middleware configuration that owns a copy of
    /// `service_name`.
    pub fn new(service_name: &str) -> CronbackRpcMiddleware {
        let service_name = service_name.to_owned();
        Self { service_name }
    }
}
26
27impl<S> Layer<S> for CronbackRpcMiddleware {
28    type Service = InnerMiddleware<S>;
29
30    fn layer(&self, service: S) -> Self::Service {
31        InnerMiddleware::new(&self.service_name, service)
32    }
33}
34
/// Wrapper around an inner service `S` together with the service name
/// used to label metrics.
#[derive(Debug, Clone)]
pub struct InnerMiddleware<S> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Value of the "service" label in emitted metrics.
    service_name: String,
}

impl<S> InnerMiddleware<S> {
    /// Wraps `inner`, storing an owned copy of `service_name`.
    pub fn new(service_name: &str, inner: S) -> Self {
        let service_name = String::from(service_name);
        Self {
            service_name,
            inner,
        }
    }
}
49
50impl<S> Service<hyper::Request<Body>> for InnerMiddleware<S>
51where
52    S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>>
53        + Clone
54        + Send
55        + 'static,
56    S::Future: Send + 'static,
57{
58    type Error = S::Error;
59    type Future = futures::future::BoxFuture<
60        'static,
61        Result<Self::Response, Self::Error>,
62    >;
63    type Response = S::Response;
64
65    fn poll_ready(
66        &mut self,
67        cx: &mut Context<'_>,
68    ) -> Poll<Result<(), Self::Error>> {
69        self.inner.poll_ready(cx)
70    }
71
72    fn call(&mut self, mut req: hyper::Request<Body>) -> Self::Future {
73        // This is necessary because tonic internally uses
74        // `tower::buffer::Buffer`. See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149
75        // for details on why this is necessary
76        let clone = self.inner.clone();
77        let mut inner = std::mem::replace(&mut self.inner, clone);
78
79        // Do we have a x-cronback-request-id header? Only used in grpc
80        // services. The api-server will set the request-id header with
81        // a random value.
82        if let Some(cronback_request_id) = req.headers().get(REQUEST_ID_HEADER)
83        {
84            // If so, set the request id to the value of the header
85            let cronback_request_id = cronback_request_id.to_str().unwrap();
86            let cronback_request_id =
87                RequestId::from(cronback_request_id.to_owned());
88            req.extensions_mut().insert(cronback_request_id);
89        }
90
91        // Do we have a x-cronback-project-id header?
92        // If project-id is set, it must be valid. We store the result in
93        // extensions.
94        if let Some(project_id) = req.headers().get(PROJECT_ID_HEADER) {
95            // If so, set the project id to the value of the header
96            let project_id = project_id.to_str().unwrap();
97            let maybe_project_id =
98                ProjectId::from(project_id.to_owned()).validated();
99            req.extensions_mut().insert(maybe_project_id);
100        }
101
102        // Removes the leading '/' in the path.
103        let endpoint = req.uri().path()[1..].to_owned();
104        let service_name = self.service_name.clone();
105        let start = Instant::now();
106        increment_counter!(
107            "rpc.requests_total",
108            "service" => service_name.clone(),
109            "endpoint" => endpoint.clone()
110        );
111        Box::pin(async move {
112            let mut response = inner.call(req).await?;
113            let latency_s = (Instant::now() - start).as_secs_f64();
114            histogram!(
115                "rpc.duration_seconds",
116                latency_s,
117                "service" => service_name.clone(),
118                "endpoint" => endpoint.clone(),
119            );
120
121            // Inject request_id into response headers
122            if let Some(request_id) =
123                response.extensions().get::<RequestId>().cloned()
124            {
125                response.headers_mut().insert(
126                    REQUEST_ID_HEADER,
127                    request_id.to_string().parse().unwrap(),
128                );
129            }
130
131            // Inject project_id into response headers
132            if let Some(project_id) = response
133                .extensions()
134                .get::<ValidShardedId<ProjectId>>()
135                .cloned()
136            {
137                response.headers_mut().insert(
138                    PROJECT_ID_HEADER,
139                    project_id.to_string().parse().unwrap(),
140                );
141            }
142
143            Ok(response)
144        })
145    }
146}