1use crate::token_source::auth_token_generator::GoogleAuthTokenGenerator;
2use chrono::Utc;
3use futures::{Future, TryFutureExt};
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use tonic::client::GrpcService;
8use tower::Service;
9use tower_layer::Layer;
10use tracing::*;
11
12#[derive(Clone)]
13pub struct GoogleAuthMiddlewareService<T>
14where
15 T: Clone,
16{
17 google_service: Option<T>,
18 token_generator: Arc<GoogleAuthTokenGenerator>,
19 cloud_resource_prefix: Option<String>,
20 user_agent: String,
21 x_goog_api_client: String,
22 additional_headers: hyper::header::HeaderMap,
23}
24
25impl<T> GoogleAuthMiddlewareService<T>
26where
27 T: Clone,
28{
29 pub fn new(
30 service: T,
31 token_generator: Arc<GoogleAuthTokenGenerator>,
32 cloud_resource_prefix: Option<String>,
33 ) -> GoogleAuthMiddlewareService<T> {
34 GoogleAuthMiddlewareService {
35 google_service: Some(service),
36 token_generator,
37 cloud_resource_prefix,
38 user_agent: format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION")),
39 x_goog_api_client: format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION")),
40 additional_headers: hyper::header::HeaderMap::new(),
41 }
42 }
43
44 pub fn set_user_agent(&mut self, user_agent: String) {
45 self.user_agent = user_agent;
46 }
47
48 pub fn set_x_goog_api_client(&mut self, x_goog_api_client: String) {
49 self.x_goog_api_client = x_goog_api_client;
50 }
51
52 pub fn append_user_agent(&mut self, user_agent: String) {
53 self.user_agent = format!("{} {}", self.user_agent, user_agent);
54 }
55
56 pub fn append_x_goog_api_client(&mut self, x_goog_api_client: String) {
57 self.x_goog_api_client = format!("{} {}", self.x_goog_api_client, x_goog_api_client);
58 }
59
60 pub fn set_additional_headers(&mut self, additional_headers: hyper::HeaderMap) {
61 self.additional_headers = additional_headers;
62 }
63}
64
65impl<T, RequestBody> Service<hyper::Request<RequestBody>> for GoogleAuthMiddlewareService<T>
66where
67 T: GrpcService<RequestBody> + Send + Clone + 'static,
68 T::Future: 'static + Send,
69 RequestBody: 'static + Send,
70 T::ResponseBody: 'static + Send,
71 T::Error: 'static + Send,
72{
73 type Response = hyper::Response<T::ResponseBody>;
74 type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
75 type Future =
76 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
77
78 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
79 if let Some(ref mut google_service) = self.google_service.as_mut() {
80 google_service.poll_ready(cx).map_err(|e| e.into())
81 } else {
82 Poll::Pending
83 }
84 }
85
86 fn call(&mut self, mut req: hyper::Request<RequestBody>) -> Self::Future {
87 let generator = self.token_generator.clone();
88 let cloud_resource_prefix = self.cloud_resource_prefix.clone();
89 let user_agent = self.user_agent.clone();
90 let x_goog_api_client = self.x_goog_api_client.clone();
91 let additional_headers = self.additional_headers.clone();
92
93 if let Some(mut google_service) = self.google_service.take() {
94 self.google_service = Some(google_service.clone());
95 Box::pin(async move {
96 let begin_time = Utc::now();
97 let token = generator.create_token().await.map_err(Box::new)?;
98 let token_generated_time = Utc::now();
99 let headers = req.headers_mut();
100 headers.insert("authorization", token.header_value().parse()?);
101 if let Some(cloud_resource_prefix_value) = cloud_resource_prefix {
102 headers.insert(
103 "google-cloud-resource-prefix",
104 cloud_resource_prefix_value.parse()?,
105 );
106 }
107 headers.insert(hyper::header::USER_AGENT, user_agent.parse()?);
108 headers.insert("x-goog-api-client", x_goog_api_client.parse()?);
109
110 for (maybe_k, v) in additional_headers.into_iter() {
111 if let Some(k) = maybe_k {
112 headers.insert(k, v);
113 }
114 }
115
116 let req_uri_str = req.uri().to_string();
117 google_service
118 .call(req)
119 .map_ok(|x| {
120 let finished_time = Utc::now();
121 debug!(
122 "OK: {} took {}ms (incl. token gen: {}ms)",
123 req_uri_str,
124 finished_time
125 .signed_duration_since(begin_time)
126 .num_milliseconds(),
127 token_generated_time
128 .signed_duration_since(begin_time)
129 .num_milliseconds()
130 );
131 x
132 })
133 .await
134 .map_err(|e| {
135 let finished_time = Utc::now();
136 error!(
137 "Err: {} took {}ms (incl. token gen: {}ms)",
138 req_uri_str,
139 finished_time
140 .signed_duration_since(begin_time)
141 .num_milliseconds(),
142 token_generated_time
143 .signed_duration_since(begin_time)
144 .num_milliseconds()
145 );
146 e.into()
147 })
148 })
149 } else {
150 panic!("Should never happen, system error");
151 }
152 }
153}
154
155pub struct GoogleAuthMiddlewareLayer {
156 pub token_generator: Arc<GoogleAuthTokenGenerator>,
157 pub cloud_resource_prefix: Option<String>,
158 pub user_agent: String,
159 pub x_goog_api_client: String,
160 pub additional_headers: hyper::header::HeaderMap,
161}
162
163impl GoogleAuthMiddlewareLayer {
164 pub fn new(
165 token_generator: GoogleAuthTokenGenerator,
166 cloud_resource_prefix: Option<String>,
167 ) -> Self {
168 GoogleAuthMiddlewareLayer {
169 token_generator: Arc::new(token_generator),
170 cloud_resource_prefix,
171 user_agent: format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION")),
172 x_goog_api_client: format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION")),
173 additional_headers: hyper::header::HeaderMap::new(),
174 }
175 }
176
177 pub fn amend_user_agent(mut self, user_agent: String) -> Self {
178 self.user_agent = format!("{} {}", self.user_agent, user_agent);
179 self
180 }
181
182 pub fn amend_x_goog_api_client(mut self, x_goog_api_client: String) -> Self {
183 self.x_goog_api_client = format!("{} {}", self.x_goog_api_client, x_goog_api_client);
184 self
185 }
186
187 pub fn set_additional_headers(&mut self, additional_headers: hyper::HeaderMap) {
188 self.additional_headers = additional_headers;
189 }
190}
191
192impl<S> Layer<S> for GoogleAuthMiddlewareLayer
193where
194 S: Clone,
195{
196 type Service = GoogleAuthMiddlewareService<S>;
197
198 fn layer(&self, service: S) -> GoogleAuthMiddlewareService<S> {
199 let mut middleware_service = GoogleAuthMiddlewareService::new(
200 service,
201 self.token_generator.clone(),
202 self.cloud_resource_prefix.clone(),
203 );
204 middleware_service.set_user_agent(self.user_agent.clone());
205 middleware_service.set_x_goog_api_client(self.x_goog_api_client.clone());
206 middleware_service.set_additional_headers(self.additional_headers.clone());
207 middleware_service
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_source::{Source, Token, TokenSourceType};
    use async_trait::async_trait;
    use hyper::{Request, Response};
    use secret_vault_value::SecretValue;
    use std::convert::Infallible;

    /// Token source that always hands out the same bearer token, valid for one hour.
    struct DummySource;

    #[async_trait]
    impl Source for DummySource {
        async fn token(&self) -> crate::error::Result<Token> {
            Ok(Token {
                token_type: "Bearer".to_string(),
                token: SecretValue::from("dummy-token"),
                expiry: Utc::now() + chrono::Duration::hours(1),
            })
        }
    }

    /// Inner service that forwards every request it receives into a channel,
    /// so tests can inspect the headers the middleware added.
    #[derive(Clone)]
    struct DummyService {
        tx: Arc<tokio::sync::mpsc::Sender<Request<String>>>,
    }

    impl Service<Request<String>> for DummyService {
        type Response = Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<String>) -> Self::Future {
            let sender = Arc::clone(&self.tx);
            Box::pin(async move {
                sender.send(req).await.unwrap();
                let response = Response::builder()
                    .status(200)
                    .body(String::new())
                    .unwrap();
                Ok(response)
            })
        }
    }

    /// Builds a token generator backed by [`DummySource`].
    async fn dummy_token_generator() -> GoogleAuthTokenGenerator {
        GoogleAuthTokenGenerator::new(
            TokenSourceType::ExternalSource(Box::new(DummySource)),
            vec![],
        )
        .await
        .unwrap()
    }

    /// Creates a `DummyService` plus the receiver for the requests it captures.
    fn capturing_service() -> (DummyService, tokio::sync::mpsc::Receiver<Request<String>>) {
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        (DummyService { tx: Arc::new(tx) }, rx)
    }

    /// An empty-bodied request to `http://example.com`.
    fn example_request() -> Request<String> {
        Request::builder()
            .uri("http://example.com")
            .body(String::new())
            .unwrap()
    }

    /// The default value of both client-identification headers.
    fn sdk_identifier() -> String {
        format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION"))
    }

    #[tokio::test]
    async fn test_headers_presence() {
        let token_generator = Arc::new(dummy_token_generator().await);
        let (dummy_service, mut captured) = capturing_service();
        let mut service = GoogleAuthMiddlewareService::new(dummy_service, token_generator, None);

        tower::Service::call(&mut service, example_request())
            .await
            .unwrap();

        let request = captured.recv().await.unwrap();
        let headers = request.headers();
        let expected_default = sdk_identifier();
        assert_eq!(
            headers.get(hyper::header::USER_AGENT).unwrap(),
            expected_default.as_str()
        );
        assert_eq!(
            headers.get("x-goog-api-client").unwrap(),
            expected_default.as_str()
        );
        assert_eq!(
            headers.get("authorization").unwrap(),
            "Bearer dummy-token"
        );
    }

    #[tokio::test]
    async fn test_headers_amend() {
        let token_generator = dummy_token_generator().await;
        let (dummy_service, mut captured) = capturing_service();

        let mut service = GoogleAuthMiddlewareLayer::new(token_generator, None)
            .amend_user_agent("extra-ua".to_string())
            .amend_x_goog_api_client("extra-client".to_string())
            .layer(dummy_service);

        tower::Service::call(&mut service, example_request())
            .await
            .unwrap();

        let request = captured.recv().await.unwrap();
        let headers = request.headers();
        let expected_ua = format!("{} extra-ua", sdk_identifier());
        let expected_client = format!("{} extra-client", sdk_identifier());

        assert_eq!(
            headers.get(hyper::header::USER_AGENT).unwrap(),
            expected_ua.as_str()
        );
        assert_eq!(
            headers.get("x-goog-api-client").unwrap(),
            expected_client.as_str()
        );
    }

    #[tokio::test]
    async fn test_additional_headers() {
        let token_generator = Arc::new(dummy_token_generator().await);
        let (dummy_service, mut captured) = capturing_service();
        let mut service = GoogleAuthMiddlewareService::new(dummy_service, token_generator, None);

        let mut extra_headers = hyper::HeaderMap::new();
        extra_headers.insert("x-test-header", "test-value".parse().unwrap());
        service.set_additional_headers(extra_headers);

        tower::Service::call(&mut service, example_request())
            .await
            .unwrap();

        let request = captured.recv().await.unwrap();
        assert_eq!(
            request.headers().get("x-test-header").unwrap(),
            "test-value"
        );
    }
}