gcloud-sdk 0.30.0

Async Google gRPC/REST APIs and a client implementation that hides the complexity of GCP authentication, built on Tonic middleware and Reqwest.
Documentation
use crate::token_source::auth_token_generator::GoogleAuthTokenGenerator;
use chrono::Utc;
use futures::{Future, TryFutureExt};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tonic::client::GrpcService;
use tower::Service;
use tower_layer::Layer;
use tracing::*;

/// Service wrapper that adds a Google auth token and client-identification headers to
/// each request before passing it on to the wrapped service `T`.
#[derive(Clone)]
pub struct GoogleAuthMiddlewareService<T>
where
    T: Clone,
{
    /// The wrapped service. Stored in an `Option` so that `call` can move the current
    /// instance into the request future while leaving a clone behind.
    google_service: Option<T>,
    /// Shared generator that is asked for a token on every call.
    token_generator: Arc<GoogleAuthTokenGenerator>,
    /// When set, sent as the `google-cloud-resource-prefix` header.
    cloud_resource_prefix: Option<String>,
    /// Value sent as the `user-agent` header.
    user_agent: String,
    /// Value sent as the `x-goog-api-client` header.
    x_goog_api_client: String,
    /// Extra headers applied after the built-in ones, so an entry here overrides a
    /// built-in header of the same name.
    additional_headers: hyper::header::HeaderMap,
}

/// Construction and header configuration for [`GoogleAuthMiddlewareService`].
impl<T> GoogleAuthMiddlewareService<T>
where
    T: Clone,
{
    /// Wraps `service` so that requests are authenticated with tokens from
    /// `token_generator`.
    ///
    /// Both the user agent and the `x-goog-api-client` value start out as
    /// `gcloud-sdk-rs/<crate version>`, and no additional headers are configured.
    pub fn new(
        service: T,
        token_generator: Arc<GoogleAuthTokenGenerator>,
        cloud_resource_prefix: Option<String>,
    ) -> GoogleAuthMiddlewareService<T> {
        // The two identification headers share the same default value.
        let default_client_id = format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION"));

        Self {
            google_service: Some(service),
            token_generator,
            cloud_resource_prefix,
            x_goog_api_client: default_client_id.clone(),
            user_agent: default_client_id,
            additional_headers: hyper::header::HeaderMap::new(),
        }
    }

    /// Replaces the user agent value.
    pub fn set_user_agent(&mut self, user_agent: String) {
        self.user_agent = user_agent;
    }

    /// Replaces the `x-goog-api-client` value.
    pub fn set_x_goog_api_client(&mut self, x_goog_api_client: String) {
        self.x_goog_api_client = x_goog_api_client;
    }

    /// Appends `user_agent` to the current user agent, separated by a single space.
    pub fn append_user_agent(&mut self, user_agent: String) {
        self.user_agent.push(' ');
        self.user_agent.push_str(&user_agent);
    }

    /// Appends `x_goog_api_client` to the current `x-goog-api-client` value, separated
    /// by a single space.
    pub fn append_x_goog_api_client(&mut self, x_goog_api_client: String) {
        self.x_goog_api_client.push(' ');
        self.x_goog_api_client.push_str(&x_goog_api_client);
    }

    /// Replaces the whole set of additional headers.
    pub fn set_additional_headers(&mut self, additional_headers: hyper::HeaderMap) {
        self.additional_headers = additional_headers;
    }
}

/// [`Service`] implementation that authenticates each request before forwarding it to
/// the wrapped gRPC service.
impl<T, RequestBody> Service<hyper::Request<RequestBody>> for GoogleAuthMiddlewareService<T>
where
    T: GrpcService<RequestBody> + Send + Clone + 'static,
    T::Future: 'static + Send,
    RequestBody: 'static + Send,
    T::ResponseBody: 'static + Send,
    T::Error: 'static + Send,
{
    type Response = hyper::Response<T::ResponseBody>;
    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Reports the readiness of the wrapped service, converting its error into the
    /// boxed error type.
    ///
    /// # Errors
    ///
    /// Returns the wrapped service's readiness error, or an error if the wrapped
    /// service is missing.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.google_service.as_mut() {
            Some(google_service) => google_service.poll_ready(cx).map_err(|e| e.into()),
            // Returning `Pending` here would never be followed by a wake-up, so the
            // caller would hang; surface the broken invariant as an error instead.
            None => Poll::Ready(Err(
                "GoogleAuthMiddlewareService has no inner service to poll".into(),
            )),
        }
    }

    /// Generates a token, sets the authentication and identification headers on `req`,
    /// forwards it to the wrapped service and logs the elapsed time.
    ///
    /// Headers are applied in this order: `authorization`, the optional
    /// `google-cloud-resource-prefix`, `user-agent`, `x-goog-api-client`, and finally
    /// the additional headers. An additional header replaces any header of the same
    /// name already on the request, and all of its values are kept.
    ///
    /// # Errors
    ///
    /// The returned future fails if token generation fails, if one of the configured
    /// strings is not a valid header value, or if the wrapped service returns an error.
    ///
    /// # Panics
    ///
    /// Panics if the wrapped service is missing.
    fn call(&mut self, mut req: hyper::Request<RequestBody>) -> Self::Future {
        let generator = self.token_generator.clone();
        let cloud_resource_prefix = self.cloud_resource_prefix.clone();
        let user_agent = self.user_agent.clone();
        let x_goog_api_client = self.x_goog_api_client.clone();
        let additional_headers = self.additional_headers.clone();

        if let Some(mut google_service) = self.google_service.take() {
            // Move the instance that was polled for readiness into the future and keep
            // a fresh clone for subsequent calls.
            self.google_service = Some(google_service.clone());
            Box::pin(async move {
                let begin_time = Utc::now();
                let token = generator.create_token().await.map_err(Box::new)?;
                let token_generated_time = Utc::now();
                let headers = req.headers_mut();
                headers.insert("authorization", token.header_value().parse()?);
                if let Some(cloud_resource_prefix_value) = cloud_resource_prefix {
                    headers.insert(
                        "google-cloud-resource-prefix",
                        cloud_resource_prefix_value.parse()?,
                    );
                }
                headers.insert(hyper::header::USER_AGENT, user_agent.parse()?);
                headers.insert("x-goog-api-client", x_goog_api_client.parse()?);

                // Iterating a `HeaderMap` by value yields `None` as the name for the
                // second and later values of a multi-valued header. `extend` understands
                // that layout: the first value of each name replaces what is already on
                // the request and the remaining values are appended, so none are lost.
                headers.extend(additional_headers);

                let req_uri_str = req.uri().to_string();
                google_service
                    .call(req)
                    .map_ok(|x| {
                        let finished_time = Utc::now();
                        debug!(
                            "OK: {} took {}ms (incl. token gen: {}ms)",
                            req_uri_str,
                            finished_time
                                .signed_duration_since(begin_time)
                                .num_milliseconds(),
                            token_generated_time
                                .signed_duration_since(begin_time)
                                .num_milliseconds()
                        );
                        x
                    })
                    .await
                    .map_err(|e| {
                        let finished_time = Utc::now();
                        error!(
                            "Err: {} took {}ms (incl. token gen: {}ms)",
                            req_uri_str,
                            finished_time
                                .signed_duration_since(begin_time)
                                .num_milliseconds(),
                            token_generated_time
                                .signed_duration_since(begin_time)
                                .num_milliseconds()
                        );
                        e.into()
                    })
            })
        } else {
            panic!("Should never happen, system error");
        }
    }
}

/// Layer configuration from which `layer` builds a [`GoogleAuthMiddlewareService`]
/// around a given service.
pub struct GoogleAuthMiddlewareLayer {
    /// Shared token generator handed to every service built by this layer.
    pub token_generator: Arc<GoogleAuthTokenGenerator>,
    /// Optional cloud resource prefix passed on to the built service.
    pub cloud_resource_prefix: Option<String>,
    /// User agent value passed on to the built service.
    pub user_agent: String,
    /// `x-goog-api-client` value passed on to the built service.
    pub x_goog_api_client: String,
    /// Additional headers passed on to the built service.
    pub additional_headers: hyper::header::HeaderMap,
}

/// Construction and header configuration for [`GoogleAuthMiddlewareLayer`].
impl GoogleAuthMiddlewareLayer {
    /// Creates a layer that owns `token_generator` behind an `Arc`.
    ///
    /// Both the user agent and the `x-goog-api-client` value start out as
    /// `gcloud-sdk-rs/<crate version>`, and no additional headers are configured.
    pub fn new(
        token_generator: GoogleAuthTokenGenerator,
        cloud_resource_prefix: Option<String>,
    ) -> Self {
        // The two identification headers share the same default value.
        let default_client_id = format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION"));

        Self {
            token_generator: Arc::new(token_generator),
            cloud_resource_prefix,
            x_goog_api_client: default_client_id.clone(),
            user_agent: default_client_id,
            additional_headers: hyper::header::HeaderMap::new(),
        }
    }

    /// Returns the layer with `user_agent` appended to the current user agent,
    /// separated by a single space.
    pub fn amend_user_agent(mut self, user_agent: String) -> Self {
        self.user_agent.push(' ');
        self.user_agent.push_str(&user_agent);
        self
    }

    /// Returns the layer with `x_goog_api_client` appended to the current
    /// `x-goog-api-client` value, separated by a single space.
    pub fn amend_x_goog_api_client(mut self, x_goog_api_client: String) -> Self {
        self.x_goog_api_client.push(' ');
        self.x_goog_api_client.push_str(&x_goog_api_client);
        self
    }

    /// Replaces the whole set of additional headers.
    pub fn set_additional_headers(&mut self, additional_headers: hyper::HeaderMap) {
        self.additional_headers = additional_headers;
    }
}

impl<S> Layer<S> for GoogleAuthMiddlewareLayer
where
    S: Clone,
{
    type Service = GoogleAuthMiddlewareService<S>;

    fn layer(&self, service: S) -> GoogleAuthMiddlewareService<S> {
        let mut middleware_service = GoogleAuthMiddlewareService::new(
            service,
            self.token_generator.clone(),
            self.cloud_resource_prefix.clone(),
        );
        middleware_service.set_user_agent(self.user_agent.clone());
        middleware_service.set_x_goog_api_client(self.x_goog_api_client.clone());
        middleware_service.set_additional_headers(self.additional_headers.clone());
        middleware_service
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_source::{Source, Token, TokenSourceType};
    use async_trait::async_trait;
    use hyper::{Request, Response};
    use secret_vault_value::SecretValue;
    use std::convert::Infallible;

    /// Token source that always returns the same bearer token, expiring in one hour.
    struct DummySource;

    #[async_trait]
    impl Source for DummySource {
        async fn token(&self) -> crate::error::Result<Token> {
            let expiry = Utc::now() + chrono::Duration::hours(1);
            Ok(Token {
                token_type: "Bearer".to_string(),
                token: SecretValue::from("dummy-token"),
                expiry,
            })
        }
    }

    /// Inner service that forwards every request it receives into a channel and
    /// answers with an empty 200 response.
    #[derive(Clone)]
    struct DummyService {
        tx: Arc<tokio::sync::mpsc::Sender<Request<String>>>,
    }

    impl Service<Request<String>> for DummyService {
        type Response = Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<String>) -> Self::Future {
            let sender = Arc::clone(&self.tx);
            Box::pin(async move {
                sender.send(req).await.unwrap();
                let response = Response::builder()
                    .status(200)
                    .body("".to_string())
                    .unwrap();
                Ok(response)
            })
        }
    }

    /// Builds a token generator backed by [`DummySource`], passing an empty list as
    /// the second constructor argument.
    async fn dummy_token_generator() -> GoogleAuthTokenGenerator {
        GoogleAuthTokenGenerator::new(
            TokenSourceType::ExternalSource(Box::new(DummySource)),
            vec![],
        )
        .await
        .unwrap()
    }

    /// Creates a [`DummyService`] together with the receiving end of its channel.
    fn capturing_service() -> (
        DummyService,
        tokio::sync::mpsc::Receiver<Request<String>>,
    ) {
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        (DummyService { tx: Arc::new(tx) }, rx)
    }

    /// Builds the empty-bodied request used by every test.
    fn empty_request() -> Request<String> {
        Request::builder()
            .uri("http://example.com")
            .body("".to_string())
            .unwrap()
    }

    #[tokio::test]
    async fn test_headers_presence() {
        let generator = Arc::new(dummy_token_generator().await);
        let (inner, mut captured) = capturing_service();
        let mut service = GoogleAuthMiddlewareService::new(inner, generator, None);

        tower::Service::call(&mut service, empty_request())
            .await
            .unwrap();

        let captured_req = captured.recv().await.unwrap();
        let expected_default = format!("gcloud-sdk-rs/{}", env!("CARGO_PKG_VERSION"));
        assert_eq!(
            captured_req
                .headers()
                .get(hyper::header::USER_AGENT)
                .unwrap(),
            expected_default.as_str()
        );
        assert_eq!(
            captured_req.headers().get("x-goog-api-client").unwrap(),
            expected_default.as_str()
        );
        assert_eq!(
            captured_req.headers().get("authorization").unwrap(),
            "Bearer dummy-token"
        );
    }

    #[tokio::test]
    async fn test_headers_amend() {
        let generator = dummy_token_generator().await;
        let (inner, mut captured) = capturing_service();

        let layer = GoogleAuthMiddlewareLayer::new(generator, None)
            .amend_user_agent("extra-ua".to_string())
            .amend_x_goog_api_client("extra-client".to_string());
        let mut service = layer.layer(inner);

        tower::Service::call(&mut service, empty_request())
            .await
            .unwrap();

        let captured_req = captured.recv().await.unwrap();
        let expected_ua = format!("gcloud-sdk-rs/{} extra-ua", env!("CARGO_PKG_VERSION"));
        let expected_client = format!("gcloud-sdk-rs/{} extra-client", env!("CARGO_PKG_VERSION"));

        assert_eq!(
            captured_req
                .headers()
                .get(hyper::header::USER_AGENT)
                .unwrap(),
            expected_ua.as_str()
        );
        assert_eq!(
            captured_req.headers().get("x-goog-api-client").unwrap(),
            expected_client.as_str()
        );
    }

    #[tokio::test]
    async fn test_additional_headers() {
        let generator = Arc::new(dummy_token_generator().await);
        let (inner, mut captured) = capturing_service();
        let mut service = GoogleAuthMiddlewareService::new(inner, generator, None);

        let mut test_headers = hyper::HeaderMap::new();
        test_headers.insert("x-test-header", "test-value".parse().unwrap());
        service.set_additional_headers(test_headers);

        tower::Service::call(&mut service, empty_request())
            .await
            .unwrap();

        let captured_req = captured.recv().await.unwrap();
        assert_eq!(
            captured_req.headers().get("x-test-header").unwrap(),
            "test-value"
        );
    }
}