gcs-rsync 0.4.8

rsync support for gcs with higher perf than gsutil rsync
Documentation
use crate::gcp::DeserializedResponse;
use crate::Client;

use super::{Error, TokenResult};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Deserializer, Serialize};
use std::{
    fmt::{Debug, Display},
    path::Path,
};
use urlencoding::encode;

/// An OAuth2 access token as deserialized from a token endpoint response.
///
/// The response's relative `expires_in` field is converted at deserialization
/// time into the absolute `expiry` instant (see `from_expires_in`).
#[derive(Deserialize, Debug, Clone)]
pub struct Token {
    /// The opaque bearer credential returned by the endpoint.
    access_token: String,
    /// Token type reported by the endpoint; deserialized but not read by this module.
    #[allow(dead_code)]
    token_type: String,
    /// Absolute UTC instant at which the token expires, computed from `expires_in`.
    #[serde(
        deserialize_with = "from_expires_in",
        rename(deserialize = "expires_in")
    )]
    expiry: DateTime<Utc>,
    /// Scope attached to the token; `None` when the response carries no `scope` field.
    #[serde(default)]
    scope: Option<String>,
}

/// Number of microseconds in one second; used to convert second counts into
/// `chrono::Duration::microseconds` arguments.
const ONE_SECOND_TO_MICROSECONDS: i64 = 1_000_000;

fn from_expires_in<'de, D>(deserializer: D) -> std::result::Result<DateTime<Utc>, D::Error>
where
    D: Deserializer<'de>,
{
    let expires_in: i64 = Deserialize::deserialize(deserializer)?;
    Ok(Utc::now() + chrono::Duration::microseconds(expires_in * ONE_SECOND_TO_MICROSECONDS))
}

/// Human-readable summary of the token: whether it is still usable and when it
/// expires (or expired). The access token itself is never printed.
impl Display for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Guard clause for the expired case; the valid case falls through.
        if !self.is_valid() {
            return write!(f, "Invalid Token expired at {}", self.expiry);
        }
        write!(f, "Valid Token expires at {}", self.expiry)
    }
}

/// The raw access-token string, as returned by `Token::access_token`.
pub type AccessToken = String;

impl Token {
    /// Returns an owned copy of the access-token string.
    pub fn access_token(&self) -> AccessToken {
        self.access_token.clone()
    }

    /// Tells whether the token can still be used.
    ///
    /// A token counts as valid only while more than 30 seconds remain before
    /// its expiry.
    pub fn is_valid(&self) -> bool {
        let safety_margin = chrono::Duration::microseconds(30 * ONE_SECOND_TO_MICROSECONDS);
        let usable_until = self.expiry - safety_margin;
        usable_until > Utc::now()
    }

    /// Consumes the token and returns it with its scope replaced by `scope`.
    pub fn with_scope(self, scope: String) -> Self {
        Self {
            scope: Some(scope),
            ..self
        }
    }
}

/// A source of OAuth2 access tokens.
///
/// Implementors must be `Sync + Send` so a generator can be shared across tasks.
#[async_trait::async_trait]
pub trait TokenGenerator: Sync + Send {
    /// Requests a token, using `client` to perform the HTTP call.
    async fn get(&self, client: &Client) -> TokenResult<Token>;
}

/// Opaque `Debug` output for trait objects.
///
/// The concrete generator type is erased, so only the trait name is printed.
/// Writing nothing at all would leave a blank value in the derived `Debug`
/// output of any struct that stores a generator.
impl Debug for dyn TokenGenerator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("TokenGenerator")
    }
}

#[async_trait::async_trait]
impl TokenGenerator for AuthorizedUserCredentials {
    async fn get(&self, client: &Client) -> TokenResult<Token> {
        let req = self;
        let token: DeserializedResponse<Token> = client
            .client
            .post("https://accounts.google.com/o/oauth2/token")
            .json(&req)
            .send()
            .await
            .map_err(Error::HttpError)?
            .json()
            .await
            .map_err(Error::HttpError)?;
        token
            .into_result()
            .map_err(super::Error::unexpected_api_response::<Token>)
    }
}

#[async_trait::async_trait]
impl TokenGenerator for ServiceAccountCredentials {
    async fn get(&self, client: &Client) -> TokenResult<Token> {
        let now = chrono::Utc::now().timestamp();
        let exp = now + 3600;

        let scope = self.scope.to_owned().ok_or(super::Error::MissingScope)?;

        let claims = Claims {
            iss: self.client_email.as_str(),
            scope: scope.as_str(),
            aud: "https://www.googleapis.com/oauth2/v4/token",
            exp,
            iat: now,
        };
        let header = jsonwebtoken::Header {
            alg: jsonwebtoken::Algorithm::RS256,
            ..Default::default()
        };
        let private_key = jsonwebtoken::EncodingKey::from_rsa_pem(self.private_key.as_bytes())
            .map_err(Error::JWTError)?;
        let jwt = jsonwebtoken::encode(&header, &claims, &private_key).map_err(Error::JWTError)?;
        let form = [
            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
            ("assertion", &jwt),
        ];

        let token: DeserializedResponse<Token> = client
            .client
            .post("https://www.googleapis.com/oauth2/v4/token")
            .form(&form)
            .send()
            .await
            .map_err(Error::HttpError)?
            .json()
            .await
            .map_err(Error::HttpError)?;
        token
            .into_result()
            .map(|t| t.with_scope(scope))
            .map_err(super::Error::unexpected_api_response::<Token>)
    }
}

/// Token generation through the GCP metadata server.
#[async_trait::async_trait]
impl TokenGenerator for GoogleMetadataServerCredentials {
    /// Requests a token for the instance's default service account from the
    /// metadata server.
    ///
    /// When a scope is configured it is sent as the `scopes` query parameter;
    /// otherwise the bare token URL is requested.
    ///
    /// # Errors
    ///
    /// Transport and body-decoding failures are reported as `Error::HttpError`;
    /// a response that does not carry a token is mapped through
    /// `Error::unexpected_api_response`.
    async fn get(&self, client: &Client) -> TokenResult<Token> {
        const DEFAULT_TOKEN_GCP_URI: &str = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token";
        let uri = match self.scope {
            None => DEFAULT_TOKEN_GCP_URI.to_owned(),
            // Percent-encode only the value. Encoding the whole `scopes=<value>`
            // pair would escape the `=` separator as `%3D`, so the request would
            // no longer contain a `scopes` parameter at all.
            Some(ref scope) => format!(
                "{}?scopes={}",
                DEFAULT_TOKEN_GCP_URI,
                encode(scope.as_str())
            ),
        };

        let token: DeserializedResponse<Token> = client
            .client
            .get(uri)
            // The metadata server requires this header on every request.
            .header("Metadata-Flavor", "Google")
            .send()
            .await
            .map_err(Error::HttpError)?
            .json()
            .await
            .map_err(Error::HttpError)?;
        token
            .into_result()
            .map_err(super::Error::unexpected_api_response::<Token>)
    }
}

fn from_str<T>(str: &str) -> TokenResult<T>
where
    T: serde::de::DeserializeOwned,
{
    serde_json::from_str(str).map_err(Error::deserialization_error::<T>)
}

async fn from_file<T, U>(file_path: T) -> TokenResult<U>
where
    T: AsRef<Path>,
    U: serde::de::DeserializeOwned,
{
    tokio::fs::read_to_string(file_path.as_ref())
        .await
        .map_err(|err| Error::io_error("error while reading file", file_path.as_ref(), err))
        .and_then(|f| from_str(f.as_str()))
}

async fn default<T>() -> TokenResult<T>
where
    T: serde::de::DeserializeOwned,
{
    let default_path = {
        let key = "GOOGLE_APPLICATION_CREDENTIALS";
        std::env::var(key).map_err(|err| Error::env_var_error(key, err))?
    };
    from_file(default_path).await
}

/// Claim set serialized into the JWT assertion that a service account signs
/// when requesting a token.
#[derive(Serialize, Debug)]
struct Claims<'a> {
    /// Issuer: the service account's client email.
    iss: &'a str,
    /// Audience: the token endpoint URL.
    aud: &'a str,
    /// Expiry of the assertion, as a Unix timestamp in seconds.
    exp: i64,
    /// Issue time of the assertion, as a Unix timestamp in seconds.
    iat: i64,
    /// Scope being requested.
    scope: &'a str,
}

/// Authorized-user credentials (client id/secret plus a refresh token).
///
/// The struct is deserialized from a credentials JSON document and is also
/// serialized as-is into the body of the token request.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct AuthorizedUserCredentials {
    client_id: String,
    client_secret: String,
    refresh_token: String,
    /// Grant type sent with the request; defaults to `"refresh_token"` when
    /// the JSON document has no `grant_type` field.
    #[serde(default = "refresh_token")]
    grant_type: String,
}

/// Serde default for `AuthorizedUserCredentials::grant_type`: the literal
/// grant type `"refresh_token"`.
fn refresh_token() -> String {
    String::from("refresh_token")
}

impl AuthorizedUserCredentials {
    /// Parses credentials from a JSON string.
    pub fn from(s: &str) -> TokenResult<Self> {
        from_str::<Self>(s)
    }

    /// Reads and parses credentials from the JSON file at `file_path`.
    pub async fn from_file<T: AsRef<Path>>(file_path: T) -> TokenResult<Self> {
        from_file::<T, Self>(file_path).await
    }

    /// Loads credentials from the file named by the
    /// `GOOGLE_APPLICATION_CREDENTIALS` environment variable.
    pub async fn default() -> TokenResult<Self> {
        default::<Self>().await
    }
}

/// Service-account key credentials, deserialized from a key JSON document.
///
/// Every field except `scope` is required in the JSON input.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct ServiceAccountCredentials {
    /// Raw identifier because `type` is a Rust keyword; maps to the JSON `type` field.
    r#type: String,
    project_id: String,
    private_key_id: String,
    /// RSA private key in PEM form; used to sign the JWT assertion.
    private_key: String,
    /// Used as the issuer (`iss`) of the JWT assertion.
    client_email: String,
    client_id: String,
    auth_uri: String,
    token_uri: String,
    auth_provider_x509_cert_url: String,
    client_x509_cert_url: String,
    /// Scope to request; `None` when absent from the JSON. A scope must be set
    /// (see `with_scope`) before a token can be requested.
    #[serde(default)]
    scope: Option<String>,
}

impl ServiceAccountCredentials {
    /// Parses credentials from a JSON string.
    pub fn from(s: &str) -> TokenResult<Self> {
        from_str::<Self>(s)
    }

    /// Reads and parses credentials from the JSON file at `file_path`.
    pub async fn from_file<T: AsRef<Path>>(file_path: T) -> TokenResult<Self> {
        from_file::<T, Self>(file_path).await
    }

    /// Loads credentials from the file named by the
    /// `GOOGLE_APPLICATION_CREDENTIALS` environment variable.
    pub async fn default() -> TokenResult<Self> {
        default::<Self>().await
    }

    /// Consumes the credentials and returns them with the requested scope set
    /// to `scope`, replacing any previous value.
    pub fn with_scope(self, scope: &str) -> Self {
        Self {
            scope: Some(String::from(scope)),
            ..self
        }
    }
}

/// Credentials that obtain tokens from the GCP metadata server rather than
/// from a key file.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct GoogleMetadataServerCredentials {
    /// Optional scope to request; when `None` the token URL is queried without
    /// any query string.
    scope: Option<String>,
}

impl GoogleMetadataServerCredentials {
    pub fn new() -> TokenResult<Self> {
        Ok(GoogleMetadataServerCredentials { scope: None })
    }
    pub fn with_scope(mut self, scope: &str) -> Self {
        self.scope = Some(scope.to_owned());
        self
    }
}

#[cfg(test)]
mod tests {
    use crate::gcp::oauth2::token::*;

    /// Builds a token that expires `seconds_from_now` seconds after the call.
    fn token_expiring_in(seconds_from_now: i64) -> Token {
        let lifetime =
            chrono::Duration::microseconds(seconds_from_now * ONE_SECOND_TO_MICROSECONDS);
        Token {
            access_token: "Hello".to_owned(),
            token_type: "token type".to_owned(),
            expiry: chrono::Utc::now() + lifetime,
            scope: None,
        }
    }

    /// A token response deserializes, ignoring unknown fields such as `id_token`,
    /// and yields an expiry in the future.
    #[test]
    fn token_from_json_test() {
        let raw = r#"{
            "access_token": "access_token",
            "expires_in": 3599,
            "scope": "scope",
            "token_type": "Bearer",
            "id_token": "id_token"
        }"#;

        let parsed: Token = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.access_token, "access_token");
        assert_eq!(parsed.token_type, "Bearer");
        assert!(parsed.expiry > Utc::now());
    }

    /// Authorized-user JSON deserializes, with `grant_type` falling back to its default.
    #[test]
    fn token_from_authorized_user_json_test() {
        let expected = AuthorizedUserCredentials {
            client_id: "client_id".to_owned(),
            client_secret: "client_secret".to_owned(),
            refresh_token: "refresh_token".to_owned(),
            grant_type: "refresh_token".to_owned(),
        };

        let actual: AuthorizedUserCredentials = super::from_str(
            r#"{
                   "client_id": "client_id",
                   "client_secret": "client_secret",
                   "quota_project_id": "quota_project_id",
                   "refresh_token": "refresh_token",
                   "type": "authorized_user"
            }"#,
        )
        .unwrap();

        assert_eq!(expected, actual);
    }

    /// Service-account JSON deserializes, with `scope` defaulting to `None`.
    #[test]
    fn token_from_service_account_json_test() {
        let expected = ServiceAccountCredentials {
            r#type: "service_account".to_owned(),
            project_id: "project_id".to_owned(),
            private_key_id: "private_key_id".to_owned(),
            private_key: "private_key".to_owned(),
            client_email: "client_email".to_owned(),
            client_id: "client_id".to_owned(),
            auth_uri: "auth_uri".to_owned(),
            token_uri: "token_uri".to_owned(),
            auth_provider_x509_cert_url: "auth_provider_x509_cert_url".to_owned(),
            client_x509_cert_url: "client_x509_cert_url".to_owned(),
            scope: None,
        };

        let actual: ServiceAccountCredentials = super::from_str(
            r#"{
                "type": "service_account",
                "project_id": "project_id",
                "private_key_id": "private_key_id",
                "private_key": "private_key",
                "client_email": "client_email",
                "client_id": "client_id",
                "auth_uri": "auth_uri",
                "token_uri": "token_uri",
                "auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
                "client_x509_cert_url": "client_x509_cert_url"
            }"#,
        )
        .unwrap();

        assert_eq!(expected, actual);
    }

    /// A token expiring right now is inside the 30-second margin, hence invalid.
    #[test]
    fn test_token_is_valid_false() {
        let token = token_expiring_in(0);

        assert!(!token.is_valid());
        assert!(
            token.to_string().starts_with("Invalid Token expired at"),
            "expected an invalid token but got {}",
            token
        )
    }

    /// A token expiring in 35 seconds is outside the 30-second margin, hence valid.
    #[test]
    fn test_token_is_valid_true() {
        let token = token_expiring_in(35);

        assert!(token.is_valid());
        assert!(
            token.to_string().starts_with("Valid Token expires at"),
            "expected a valid token but got {}",
            token
        )
    }
}