cronback_lib/
rpc_middleware.rs1use std::task::{Context, Poll};
2use std::time::Instant;
3
4use hyper::Body;
5use metrics::{histogram, increment_counter};
6use tonic::body::BoxBody;
7use tower::{Layer, Service};
8
9use crate::consts::{PROJECT_ID_HEADER, REQUEST_ID_HEADER};
10use crate::model::{ModelId, ValidShardedId};
11use crate::types::{ProjectId, RequestId};
12
/// Tower [`Layer`] that wraps RPC services with per-request instrumentation.
///
/// The configured service name becomes the `service` label on the metrics
/// emitted by the wrapped service. The derived `Default` yields an empty name.
#[derive(Clone, Debug, Default)]
pub struct CronbackRpcMiddleware {
    /// Value used for the `service` metric label.
    service_name: String,
}

impl CronbackRpcMiddleware {
    /// Builds a layer whose metrics are labelled with `service_name`.
    pub fn new(service_name: &str) -> CronbackRpcMiddleware {
        let service_name = String::from(service_name);
        Self { service_name }
    }
}
26
27impl<S> Layer<S> for CronbackRpcMiddleware {
28 type Service = InnerMiddleware<S>;
29
30 fn layer(&self, service: S) -> Self::Service {
31 InnerMiddleware::new(&self.service_name, service)
32 }
33}
34
/// Service wrapper produced by `CronbackRpcMiddleware::layer`.
///
/// It holds the wrapped service and the name used as the `service` metric
/// label for every request it handles.
#[derive(Clone, Debug)]
pub struct InnerMiddleware<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Value used for the `service` metric label.
    service_name: String,
}

impl<S> InnerMiddleware<S> {
    /// Wraps `inner`, storing an owned copy of `service_name`.
    pub fn new(service_name: &str, inner: S) -> Self {
        let service_name = String::from(service_name);
        Self {
            inner,
            service_name,
        }
    }
}
49
50impl<S> Service<hyper::Request<Body>> for InnerMiddleware<S>
51where
52 S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>>
53 + Clone
54 + Send
55 + 'static,
56 S::Future: Send + 'static,
57{
58 type Error = S::Error;
59 type Future = futures::future::BoxFuture<
60 'static,
61 Result<Self::Response, Self::Error>,
62 >;
63 type Response = S::Response;
64
65 fn poll_ready(
66 &mut self,
67 cx: &mut Context<'_>,
68 ) -> Poll<Result<(), Self::Error>> {
69 self.inner.poll_ready(cx)
70 }
71
72 fn call(&mut self, mut req: hyper::Request<Body>) -> Self::Future {
73 let clone = self.inner.clone();
77 let mut inner = std::mem::replace(&mut self.inner, clone);
78
79 if let Some(cronback_request_id) = req.headers().get(REQUEST_ID_HEADER)
83 {
84 let cronback_request_id = cronback_request_id.to_str().unwrap();
86 let cronback_request_id =
87 RequestId::from(cronback_request_id.to_owned());
88 req.extensions_mut().insert(cronback_request_id);
89 }
90
91 if let Some(project_id) = req.headers().get(PROJECT_ID_HEADER) {
95 let project_id = project_id.to_str().unwrap();
97 let maybe_project_id =
98 ProjectId::from(project_id.to_owned()).validated();
99 req.extensions_mut().insert(maybe_project_id);
100 }
101
102 let endpoint = req.uri().path()[1..].to_owned();
104 let service_name = self.service_name.clone();
105 let start = Instant::now();
106 increment_counter!(
107 "rpc.requests_total",
108 "service" => service_name.clone(),
109 "endpoint" => endpoint.clone()
110 );
111 Box::pin(async move {
112 let mut response = inner.call(req).await?;
113 let latency_s = (Instant::now() - start).as_secs_f64();
114 histogram!(
115 "rpc.duration_seconds",
116 latency_s,
117 "service" => service_name.clone(),
118 "endpoint" => endpoint.clone(),
119 );
120
121 if let Some(request_id) =
123 response.extensions().get::<RequestId>().cloned()
124 {
125 response.headers_mut().insert(
126 REQUEST_ID_HEADER,
127 request_id.to_string().parse().unwrap(),
128 );
129 }
130
131 if let Some(project_id) = response
133 .extensions()
134 .get::<ValidShardedId<ProjectId>>()
135 .cloned()
136 {
137 response.headers_mut().insert(
138 PROJECT_ID_HEADER,
139 project_id.to_string().parse().unwrap(),
140 );
141 }
142
143 Ok(response)
144 })
145 }
146}