1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
//! A `TokenProvider` that fetches tokens using OAuth. This is the most common way
//! to authenticate. In most cases, you'll want to wrap this in a `RenewingTokenProvider`.

use std::{str::FromStr, sync::Arc};

use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::EncodingKey;
use serde::{
    de::{self, Deserializer},
    Deserialize, Serialize,
};
use serde_with::serde_as;

use super::{Token, TokenProvider};

/// OAuth scope URL for `devstorage.full_control` (full control over Cloud Storage).
pub const SCOPE_STORAGE_FULL_CONTROL: &str =
    concat!("https://www.googleapis.com/auth/", "devstorage.full_control");

/// A `TokenProvider` that fetches access tokens via OAuth using a provided service account.
///
/// Every call to `get_token` performs a request to the service account's token URI; this type
/// keeps no token cache of its own.
pub struct OAuthTokenProvider {
    /// The scopes that will be assigned to the requested auth token.
    scope: String,

    /// The service account whose credentials sign each token request.
    service_account: ServiceAccount,

    /// The HTTP client used to reach the token endpoint.
    client: reqwest::Client,
}

impl OAuthTokenProvider {
    /// Creates a new `OAuthTokenProvider` for the service account that requests tokens with
    /// the provided scope, using a default-configured `reqwest::Client`.
    ///
    /// # Errors
    ///
    /// Returns [`OAuthError::Http`] if the default HTTP client cannot be constructed (e.g. the
    /// TLS backend fails to initialize).
    pub fn new(
        service_account: ServiceAccount,
        scope: impl Into<String>,
    ) -> Result<Self, OAuthError> {
        // `reqwest::Client::default()` / `Client::new()` panic when the client cannot be built.
        // The builder reports the same failure as an error, which this fallible constructor can
        // surface instead of aborting the caller.
        let client = reqwest::Client::builder().build()?;
        Self::new_with_client(service_account, scope, client)
    }

    /// Like `new` but also allows providing a `reqwest::Client`, if you have some special
    /// network setup.
    ///
    /// # Errors
    ///
    /// This constructor does not currently fail.
    pub fn new_with_client(
        service_account: ServiceAccount,
        scope: impl Into<String>,
        client: reqwest::Client,
    ) -> Result<Self, OAuthError> {
        Ok(Self {
            scope: scope.into(),
            service_account,
            client,
        })
    }
}

#[async_trait::async_trait]
impl TokenProvider for OAuthTokenProvider {
    async fn get_token(&self) -> anyhow::Result<Arc<Token>> {
        let header = jsonwebtoken::Header {
            alg: jsonwebtoken::Algorithm::RS256,
            ..Default::default()
        };

        let now = Utc::now();
        let expiry = now + Duration::hours(1);

        let claims = Claims {
            iss: &self.service_account.client_email,
            scope: &self.scope,
            aud: &self.service_account.token_uri,
            iat: now,
            exp: expiry,
        };

        let client_assertion =
            jsonwebtoken::encode(&header, &claims, &self.service_account.private_key)?;

        let res = self
            .client
            .post(&self.service_account.token_uri)
            .form(&[
                ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
                ("assertion", &client_assertion),
            ])
            .send()
            .await?;
        let res_status = res.status();
        let (token, expires_in) = match res.json::<OAuthResponse>().await? {
            OAuthResponse::Token {
                token: TokenKind::IdToken(..),
                ..
            } => return Err(OAuthError::InvalidScope.into()),
            OAuthResponse::Token {
                token: TokenKind::AccessToken(token),
                expires_in,
            } => (token, expires_in),
            OAuthResponse::Error {
                error_description, ..
            } => {
                return Err(OAuthError::Other(crate::api::GoogleError {
                    status: res_status,
                    message: error_description,
                })
                .into())
            }
        };

        Ok(Arc::new(Token {
            token,
            expiry: now + expires_in,
        }))
    }
}

/// The client assertion claims.
///
/// Serialized into the body of the RS256-signed JWT that `get_token` sends as the `assertion`
/// form field. Serde serializes fields in declaration order, so reordering them changes the
/// encoded token bytes (though not its meaning).
#[serde_as]
#[derive(Serialize)]
struct Claims<'a> {
    /// Issuer: the service account's `client_email`.
    iss: &'a str,

    /// Audience: the service account's `token_uri`, which is also where the assertion is posted.
    aud: &'a str,

    /// The scope(s) requested for the resulting access token.
    scope: &'a str,

    /// Expiry time of the assertion, serialized as whole Unix seconds.
    #[serde_as(as = "serde_with::TimestampSeconds")]
    exp: DateTime<Utc>,

    /// Issued-at time of the assertion, serialized as whole Unix seconds.
    #[serde_as(as = "serde_with::TimestampSeconds")]
    iat: DateTime<Utc>,
}

/// The JSON body returned by the token endpoint.
///
/// Because the enum is `untagged`, serde tries the variants in declaration order and keeps the
/// first one that matches the body. Keep `Token` before `Error`.
#[serde_as]
#[derive(Deserialize)]
#[serde(untagged)]
enum OAuthResponse {
    /// A successful response carrying a token and its lifetime.
    Token {
        /// The token itself. It is flattened, so an `id_token` or `access_token` key in the body
        /// selects the `TokenKind` variant.
        #[serde(flatten)]
        token: TokenKind,

        /// How long the token is valid for, given as whole seconds.
        #[serde_as(as = "serde_with::DurationSeconds<i64>")]
        expires_in: Duration,
    },
    /// An error response; only the human-readable `error_description` is kept.
    Error {
        error_description: String,
    },
}

/// The kind of token in a successful token response. The variant is chosen by whether the body
/// has an `id_token` or an `access_token` key (snake_case variant names).
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum TokenKind {
    /// An identity token; `get_token` rejects this with `OAuthError::InvalidScope`.
    IdToken(String),
    /// An access token, which is what `get_token` hands back to callers.
    AccessToken(String),
}

/// An error occurred while authenticating using OAuth.
#[derive(Debug, thiserror::Error)]
pub enum OAuthError {
    /// The service account's private key was invalid and could not be used for signing.
    #[error("invalid RSA private key: {0}")]
    InvalidSigningKey(#[from] jsonwebtoken::errors::Error),

    /// A network or HTTP error reported by `reqwest`.
    #[error(transparent)]
    Http(#[from] reqwest::Error),

    /// The GCP API returned some error that's not commonly encountered while using this library.
    #[error(transparent)]
    Other(#[from] crate::api::GoogleError),

    /// An invalid scope was requested, so the OAuth API returned an identity token rather than
    /// an access token.
    #[error("received an ID token instead of an access token. ensure that the scope is correct.")]
    InvalidScope,
}

impl From<crate::api::Error> for OAuthError {
    fn from(api_error: crate::api::Error) -> Self {
        match api_error {
            crate::api::Error::Http(e) => Self::Http(e),
            crate::api::Error::Google(e) => Self::Other(e),
        }
    }
}

/// A representation of a GCP service account file. Contains the information required
/// to obtain an access token via OAuth.
///
/// Obtain one with `read_from_file`, `read_from_canonical_env`, `str::parse`, or by
/// deserializing a string value that holds the file's JSON.
pub struct ServiceAccount {
    /// The account's email address; used as the JWT issuer (`iss`).
    client_email: String,
    /// The file's `private_key`, parsed from PEM as an RSA key for RS256 signing.
    private_key: EncodingKey,
    /// The OAuth token endpoint. Token requests are posted here, and it is the JWT audience.
    token_uri: String,
}

impl ServiceAccount {
    /// Reads the service account JSON file at `path` and attempts to parse it.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceAccountError::Io`] if the file cannot be read. Otherwise returns any
    /// error produced by parsing its contents (see the `FromStr` implementation).
    pub fn read_from_file(path: impl AsRef<std::path::Path>) -> Result<Self, ServiceAccountError> {
        let path = path.as_ref();
        std::fs::read_to_string(path)
            .map_err(|error| ServiceAccountError::Io {
                file: path.to_path_buf(),
                error,
            })?
            .parse()
    }

    /// Reads the `ServiceAccount` from the file pointed to by the
    /// `GOOGLE_APPLICATION_CREDENTIALS` environment variable.
    ///
    /// # Errors
    ///
    /// If the variable is unset or empty, returns [`ServiceAccountError::Io`] with an empty
    /// `file` and a `NotFound` I/O error that names the variable. Otherwise behaves like
    /// [`read_from_file`](Self::read_from_file).
    pub fn read_from_canonical_env() -> Result<Self, ServiceAccountError> {
        const ENV_VAR: &str = "GOOGLE_APPLICATION_CREDENTIALS";

        match std::env::var_os(ENV_VAR) {
            Some(service_account_path) if !service_account_path.is_empty() => {
                Self::read_from_file(service_account_path)
            }
            // Without this arm a missing variable would be read as the path "", which surfaces
            // as an opaque "No such file or directory" error that never mentions the variable.
            _ => Err(ServiceAccountError::Io {
                file: std::path::PathBuf::new(),
                error: std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("environment variable `{}` is not set or is empty", ENV_VAR),
                ),
            }),
        }
    }
}

impl FromStr for ServiceAccount {
    type Err = ServiceAccountError;
    fn from_str(sa_json: &str) -> Result<Self, Self::Err> {
        let sa: DeserializableServiceAccount = serde_json::from_str(sa_json)?;
        Ok(Self {
            client_email: sa.client_email,
            private_key: jsonwebtoken::EncodingKey::from_rsa_pem(sa.private_key.as_bytes())?,
            token_uri: sa.token_uri,
        })
    }
}
impl<'de> Deserialize<'de> for ServiceAccount {
    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        String::deserialize(d)?.parse().map_err(de::Error::custom)
    }
}

/// The raw on-disk shape of a service account JSON file. Only the fields this module needs are
/// declared; since `deny_unknown_fields` is not set, serde ignores any others.
#[derive(Deserialize)]
struct DeserializableServiceAccount {
    /// Must be `"service_account"`; validated by `ServiceAccountMarker`'s `Deserialize` impl.
    #[serde(rename = "type")]
    _ty: ServiceAccountMarker,
    client_email: String,
    /// PEM text of the private key; converted to an `EncodingKey` in `from_str`.
    private_key: String,
    token_uri: String,
}

/// The `type` in the service account JSON file.
const SERVICE_ACCOUNT_MARKER: &str = "service_account";

struct ServiceAccountMarker;

impl<'de> Deserialize<'de> for ServiceAccountMarker {
    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        let ty: String = String::deserialize(d)?;
        if ty == SERVICE_ACCOUNT_MARKER {
            Ok(Self)
        } else {
            Err(de::Error::custom(&format!(
                "provided JSON had unexpected `type` `{}`. expected `{}`.",
                ty, SERVICE_ACCOUNT_MARKER
            )))
        }
    }
}

/// An error occurring when loading the service account file.
#[derive(Debug, thiserror::Error)]
pub enum ServiceAccountError {
    /// The service account file could not be read.
    #[error("failed to read service account file `{file}`: {error}")]
    Io {
        /// The service account file, for informational purposes.
        file: std::path::PathBuf,

        /// The actual error that occured.
        #[source]
        error: std::io::Error,
    },

    /// The service account JSON could not be parsed into the expected format.
    #[error("cound not parse service account json: {0}")]
    Parse(#[from] serde_json::Error),

    /// The private key was invalid.
    #[error("invalid `private_key`: {0}")]
    InvalidKey(#[from] jsonwebtoken::errors::Error),
}

#[cfg(test)]
mod tests {
    use super::*;

    // These tests call the real token endpoint. They need `GOOGLE_APPLICATION_CREDENTIALS` to
    // point at a valid service account file, plus network access.

    /// Loads the service account from the environment, lets `tweak` adjust it, and wraps it in a
    /// provider requesting the storage full-control scope.
    fn provider_from_env(tweak: impl FnOnce(&mut ServiceAccount)) -> OAuthTokenProvider {
        let mut service_account = ServiceAccount::read_from_canonical_env().unwrap();
        tweak(&mut service_account);
        OAuthTokenProvider::new(service_account, SCOPE_STORAGE_FULL_CONTROL).unwrap()
    }

    #[tokio::test]
    async fn provides_token() {
        let provider = provider_from_env(|_| {});
        provider.get_token().await.unwrap();
    }

    #[tokio::test]
    async fn fails_sanely() {
        // A corrupted issuer email is expected to be rejected by the endpoint with a Google error.
        let provider = provider_from_env(|sa| sa.client_email.push('q'));
        let failure = provider.get_token().await.unwrap_err();
        let oauth_error = failure.downcast::<OAuthError>().unwrap();
        assert!(matches!(
            oauth_error,
            OAuthError::Other(crate::api::GoogleError { .. })
        ))
    }
}