Skip to main content

gcloud_sdk/
middleware.rs

1use crate::token_source::auth_token_generator::GoogleAuthTokenGenerator;
2use chrono::Utc;
3use futures::{Future, TryFutureExt};
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use tonic::client::GrpcService;
8use tower::Service;
9use tower_layer::Layer;
10use tracing::*;
11
12#[derive(Clone)]
13pub struct GoogleAuthMiddlewareService<T>
14where
15    T: Clone,
16{
17    google_service: Option<T>,
18    token_generator: Arc<GoogleAuthTokenGenerator>,
19    cloud_resource_prefix: Option<String>,
20    user_agent: String,
21    x_goog_api_client: String,
22    additional_headers: hyper::header::HeaderMap,
23}
24
25impl<T> GoogleAuthMiddlewareService<T>
26where
27    T: Clone,
28{
29    pub fn new(
30        service: T,
31        token_generator: Arc<GoogleAuthTokenGenerator>,
32        cloud_resource_prefix: Option<String>,
33    ) -> GoogleAuthMiddlewareService<T> {
34        GoogleAuthMiddlewareService {
35            google_service: Some(service),
36            token_generator,
37            cloud_resource_prefix,
38            user_agent: format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION")),
39            x_goog_api_client: format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION")),
40            additional_headers: hyper::header::HeaderMap::new(),
41        }
42    }
43
44    pub fn set_user_agent(&mut self, user_agent: String) {
45        self.user_agent = user_agent;
46    }
47
48    pub fn set_x_goog_api_client(&mut self, x_goog_api_client: String) {
49        self.x_goog_api_client = x_goog_api_client;
50    }
51
52    pub fn append_user_agent(&mut self, user_agent: String) {
53        self.user_agent = format!("{} {}", self.user_agent, user_agent);
54    }
55
56    pub fn append_x_goog_api_client(&mut self, x_goog_api_client: String) {
57        self.x_goog_api_client = format!("{} {}", self.x_goog_api_client, x_goog_api_client);
58    }
59
60    pub fn set_additional_headers(&mut self, additional_headers: hyper::HeaderMap) {
61        self.additional_headers = additional_headers;
62    }
63}
64
65impl<T, RequestBody> Service<hyper::Request<RequestBody>> for GoogleAuthMiddlewareService<T>
66where
67    T: GrpcService<RequestBody> + Send + Clone + 'static,
68    T::Future: 'static + Send,
69    RequestBody: 'static + Send,
70    T::ResponseBody: 'static + Send,
71    T::Error: 'static + Send,
72{
73    type Response = hyper::Response<T::ResponseBody>;
74    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
75    type Future =
76        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
77
78    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
79        if let Some(ref mut google_service) = self.google_service.as_mut() {
80            google_service.poll_ready(cx).map_err(|e| e.into())
81        } else {
82            Poll::Pending
83        }
84    }
85
86    fn call(&mut self, mut req: hyper::Request<RequestBody>) -> Self::Future {
87        let generator = self.token_generator.clone();
88        let cloud_resource_prefix = self.cloud_resource_prefix.clone();
89        let user_agent = self.user_agent.clone();
90        let x_goog_api_client = self.x_goog_api_client.clone();
91        let additional_headers = self.additional_headers.clone();
92
93        if let Some(mut google_service) = self.google_service.take() {
94            self.google_service = Some(google_service.clone());
95            Box::pin(async move {
96                let begin_time = Utc::now();
97                let token = generator.create_token().await.map_err(Box::new)?;
98                let token_generated_time = Utc::now();
99                let headers = req.headers_mut();
100                headers.insert("authorization", token.header_value().parse()?);
101                if let Some(cloud_resource_prefix_value) = cloud_resource_prefix {
102                    headers.insert(
103                        "google-cloud-resource-prefix",
104                        cloud_resource_prefix_value.parse()?,
105                    );
106                }
107                headers.insert(hyper::header::USER_AGENT, user_agent.parse()?);
108                headers.insert("x-goog-api-client", x_goog_api_client.parse()?);
109
110                for (maybe_k, v) in additional_headers.into_iter() {
111                    if let Some(k) = maybe_k {
112                        headers.insert(k, v);
113                    }
114                }
115
116                let req_uri_str = req.uri().to_string();
117                google_service
118                    .call(req)
119                    .map_ok(|x| {
120                        let finished_time = Utc::now();
121                        debug!(
122                            "OK: {} took {}ms (incl. token gen: {}ms)",
123                            req_uri_str,
124                            finished_time
125                                .signed_duration_since(begin_time)
126                                .num_milliseconds(),
127                            token_generated_time
128                                .signed_duration_since(begin_time)
129                                .num_milliseconds()
130                        );
131                        x
132                    })
133                    .await
134                    .map_err(|e| {
135                        let finished_time = Utc::now();
136                        error!(
137                            "Err: {} took {}ms (incl. token gen: {}ms)",
138                            req_uri_str,
139                            finished_time
140                                .signed_duration_since(begin_time)
141                                .num_milliseconds(),
142                            token_generated_time
143                                .signed_duration_since(begin_time)
144                                .num_milliseconds()
145                        );
146                        e.into()
147                    })
148            })
149        } else {
150            panic!("Should never happen, system error");
151        }
152    }
153}
154
155pub struct GoogleAuthMiddlewareLayer {
156    pub token_generator: Arc<GoogleAuthTokenGenerator>,
157    pub cloud_resource_prefix: Option<String>,
158    pub user_agent: String,
159    pub x_goog_api_client: String,
160    pub additional_headers: hyper::header::HeaderMap,
161}
162
163impl GoogleAuthMiddlewareLayer {
164    pub fn new(
165        token_generator: GoogleAuthTokenGenerator,
166        cloud_resource_prefix: Option<String>,
167    ) -> Self {
168        GoogleAuthMiddlewareLayer {
169            token_generator: Arc::new(token_generator),
170            cloud_resource_prefix,
171            user_agent: format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION")),
172            x_goog_api_client: format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION")),
173            additional_headers: hyper::header::HeaderMap::new(),
174        }
175    }
176
177    pub fn amend_user_agent(mut self, user_agent: String) -> Self {
178        self.user_agent = format!("{} {}", self.user_agent, user_agent);
179        self
180    }
181
182    pub fn amend_x_goog_api_client(mut self, x_goog_api_client: String) -> Self {
183        self.x_goog_api_client = format!("{} {}", self.x_goog_api_client, x_goog_api_client);
184        self
185    }
186
187    pub fn set_additional_headers(&mut self, additional_headers: hyper::HeaderMap) {
188        self.additional_headers = additional_headers;
189    }
190}
191
192impl<S> Layer<S> for GoogleAuthMiddlewareLayer
193where
194    S: Clone,
195{
196    type Service = GoogleAuthMiddlewareService<S>;
197
198    fn layer(&self, service: S) -> GoogleAuthMiddlewareService<S> {
199        let mut middleware_service = GoogleAuthMiddlewareService::new(
200            service,
201            self.token_generator.clone(),
202            self.cloud_resource_prefix.clone(),
203        );
204        middleware_service.set_user_agent(self.user_agent.clone());
205        middleware_service.set_x_goog_api_client(self.x_goog_api_client.clone());
206        middleware_service.set_additional_headers(self.additional_headers.clone());
207        middleware_service
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_source::{Source, Token, TokenSourceType};
    use async_trait::async_trait;
    use hyper::{Request, Response};
    use secret_vault_value::SecretValue;
    use std::convert::Infallible;

    /// Token source that always yields the same bearer token, expiring one
    /// hour from the moment it is requested.
    struct DummySource;

    #[async_trait]
    impl Source for DummySource {
        async fn token(&self) -> crate::error::Result<Token> {
            let expiry = Utc::now() + chrono::Duration::hours(1);
            Ok(Token {
                token_type: "Bearer".to_string(),
                token: SecretValue::from("dummy-token"),
                expiry,
            })
        }
    }

    /// Inner service that forwards every request it receives into a channel
    /// and answers with an empty `200` response.
    #[derive(Clone)]
    struct DummyService {
        tx: Arc<tokio::sync::mpsc::Sender<Request<String>>>,
    }

    impl Service<Request<String>> for DummyService {
        type Response = Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<String>) -> Self::Future {
            let sender = Arc::clone(&self.tx);
            Box::pin(async move {
                // Hand the decorated request to the test for inspection.
                sender.send(req).await.unwrap();
                let response = Response::builder()
                    .status(200)
                    .body(String::new())
                    .unwrap();
                Ok(response)
            })
        }
    }

    /// Builds a token generator backed by [`DummySource`].
    async fn dummy_token_generator() -> GoogleAuthTokenGenerator {
        GoogleAuthTokenGenerator::new(
            TokenSourceType::ExternalSource(Box::new(DummySource)),
            vec![],
        )
        .await
        .unwrap()
    }

    /// Creates a [`DummyService`] together with the receiving end of its
    /// capture channel.
    fn capturing_service() -> (DummyService, tokio::sync::mpsc::Receiver<Request<String>>) {
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        (DummyService { tx: Arc::new(tx) }, rx)
    }

    /// Sends an empty request to `http://example.com` through `service` and
    /// returns the request as the inner service received it.
    async fn dispatch(
        service: &mut GoogleAuthMiddlewareService<DummyService>,
        rx: &mut tokio::sync::mpsc::Receiver<Request<String>>,
    ) -> Request<String> {
        let req = Request::builder()
            .uri("http://example.com")
            .body(String::new())
            .unwrap();

        tower::Service::call(service, req).await.unwrap();

        rx.recv().await.unwrap()
    }

    #[tokio::test]
    async fn test_headers_presence() {
        let token_generator = dummy_token_generator().await;
        let (dummy_service, mut rx) = capturing_service();
        let mut service =
            GoogleAuthMiddlewareService::new(dummy_service, Arc::new(token_generator), None);

        let captured_req = dispatch(&mut service, &mut rx).await;

        let expected_default = format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION"));
        let headers = captured_req.headers();
        assert_eq!(
            headers.get(hyper::header::USER_AGENT).unwrap(),
            expected_default.as_str()
        );
        assert_eq!(
            headers.get("x-goog-api-client").unwrap(),
            expected_default.as_str()
        );
        assert_eq!(headers.get("authorization").unwrap(), "Bearer dummy-token");
    }

    #[tokio::test]
    async fn test_headers_amend() {
        let token_generator = dummy_token_generator().await;
        let (dummy_service, mut rx) = capturing_service();

        let layer = GoogleAuthMiddlewareLayer::new(token_generator, None)
            .amend_user_agent("extra-ua".to_string())
            .amend_x_goog_api_client("extra-client".to_string());
        let mut service = layer.layer(dummy_service);

        let captured_req = dispatch(&mut service, &mut rx).await;

        let expected_ua = format!("gcloud-sdk-rs/{} extra-ua", env!("CARGO_PKG_VERSION"));
        let expected_client = format!("gcloud-sdk-rs/{} extra-client", env!("CARGO_PKG_VERSION"));
        let headers = captured_req.headers();
        assert_eq!(
            headers.get(hyper::header::USER_AGENT).unwrap(),
            expected_ua.as_str()
        );
        assert_eq!(
            headers.get("x-goog-api-client").unwrap(),
            expected_client.as_str()
        );
    }

    #[tokio::test]
    async fn test_additional_headers() {
        let token_generator = dummy_token_generator().await;
        let (dummy_service, mut rx) = capturing_service();
        let mut service =
            GoogleAuthMiddlewareService::new(dummy_service, Arc::new(token_generator), None);

        let mut test_headers = hyper::HeaderMap::new();
        test_headers.insert(
            "x-test-header",
            hyper::header::HeaderValue::from_static("test-value"),
        );
        service.set_additional_headers(test_headers);

        let captured_req = dispatch(&mut service, &mut rx).await;

        assert_eq!(
            captured_req.headers().get("x-test-header").unwrap(),
            "test-value"
        );
    }
}
371}