cloud_storage_lite/token_provider/oauth.rs

1//! A `TokenProvider` that fetches tokens using OAuth. This is the most common way
2//! to authenticate. In most cases, you'll want to wrap this in a `RenewingTokenProvider`.
3
4use std::{str::FromStr, sync::Arc};
5
6use anyhow::Context;
7use chrono::{DateTime, Duration, Utc};
8use jsonwebtoken::EncodingKey;
9use serde::{
10    de::{self, Deserializer},
11    Deserialize, Serialize,
12};
13use serde_with::serde_as;
14
15use super::{Token, TokenProvider};
16
/// OAuth scope granting full control over Cloud Storage (`devstorage.full_control`).
pub const SCOPE_STORAGE_FULL_CONTROL: &str = "https://www.googleapis.com/auth/devstorage.full_control";
20
21/// A `TokenProvider` that fetches access tokens via OAuth using a provided service account.
22pub struct OAuthTokenProvider {
23    /// The scopes that will be assigned to the requested auth token.
24    scope: String,
25
26    service_account: ServiceAccount,
27
28    client: reqwest::Client,
29}
30
31impl OAuthTokenProvider {
32    /// Creates a new `OAuthTokenProvider` for the service account that requests tokens with
33    /// the provided scope.
34    pub fn new(
35        service_account: ServiceAccount,
36        scope: impl Into<String>,
37    ) -> Result<Self, OAuthError> {
38        Self::new_with_client(service_account, scope, Default::default())
39    }
40
41    /// Like `new` but also allows providing a `reqwest::Client`, if you have some special
42    /// network setup.
43    pub fn new_with_client(
44        service_account: ServiceAccount,
45        scope: impl Into<String>,
46        client: reqwest::Client,
47    ) -> Result<Self, OAuthError> {
48        Ok(Self {
49            scope: scope.into(),
50            service_account,
51            client,
52        })
53    }
54}
55
56#[async_trait::async_trait]
57impl TokenProvider for OAuthTokenProvider {
58    async fn get_token(&self) -> anyhow::Result<Arc<Token>> {
59        let header = jsonwebtoken::Header {
60            alg: jsonwebtoken::Algorithm::RS256,
61            ..Default::default()
62        };
63
64        let now = Utc::now();
65        let expiry = now + Duration::hours(1);
66
67        let claims = Claims {
68            iss: &self.service_account.client_email,
69            scope: &self.scope,
70            aud: &self.service_account.token_uri,
71            iat: now,
72            exp: expiry,
73        };
74
75        let client_assertion =
76            jsonwebtoken::encode(&header, &claims, &self.service_account.private_key)?;
77
78        let res = self
79            .client
80            .post(&self.service_account.token_uri)
81            .form(&[
82                ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
83                ("assertion", &client_assertion),
84            ])
85            .send()
86            .await
87            .context("failed to request access token from Google")?;
88        let res_status = res.status();
89        let (token, expires_in) = match res.json::<OAuthResponse>().await? {
90            OAuthResponse::Token {
91                token: TokenKind::IdToken(..),
92                ..
93            } => return Err(OAuthError::InvalidScope.into()),
94            OAuthResponse::Token {
95                token: TokenKind::AccessToken(token),
96                expires_in,
97            } => (token, expires_in),
98            OAuthResponse::Error {
99                error_description, ..
100            } => {
101                return Err(OAuthError::Other(crate::api::GoogleError {
102                    status: res_status,
103                    message: error_description,
104                })
105                .into())
106            }
107        };
108
109        Ok(Arc::new(Token {
110            token,
111            expiry: now + expires_in,
112        }))
113    }
114}
115
116/// The client assertion claims.
117#[serde_as]
118#[derive(Serialize)]
119struct Claims<'a> {
120    iss: &'a str,
121
122    aud: &'a str,
123
124    scope: &'a str,
125
126    #[serde_as(as = "serde_with::TimestampSeconds")]
127    exp: DateTime<Utc>,
128
129    #[serde_as(as = "serde_with::TimestampSeconds")]
130    iat: DateTime<Utc>,
131}
132
133#[serde_as]
134#[derive(Deserialize)]
135#[serde(untagged)]
136enum OAuthResponse {
137    Token {
138        #[serde(flatten)]
139        token: TokenKind,
140
141        #[serde_as(as = "serde_with::DurationSeconds<i64>")]
142        expires_in: Duration,
143    },
144    Error {
145        error_description: String,
146    },
147}
148
149#[derive(Deserialize)]
150#[serde(rename_all = "snake_case")]
151enum TokenKind {
152    IdToken(String),
153    AccessToken(String),
154}
155
156/// An error occured while authenticating using OAuth.
157#[derive(Debug, thiserror::Error)]
158pub enum OAuthError {
159    /// The service private key was invalid and could not be used for signing.
160    #[error("invalid RSA private key: {0}")]
161    InvalidSigningKey(#[from] jsonwebtoken::errors::Error),
162
163    /// A network error occurred.
164    #[error(transparent)]
165    Http(#[from] reqwest::Error),
166
167    /// The GCP API returned some error that's not commonly encountered while using this library.
168    #[error(transparent)]
169    Other(#[from] crate::api::GoogleError),
170
171    /// (Only) An invalid scope was requested, leading to the OAuth API returning an identity
172    /// token rather than an access token.
173    #[error("received an ID token instead of an access token. ensure that the scope is correct.")]
174    InvalidScope,
175}
176
177impl From<crate::api::Error> for OAuthError {
178    fn from(api_error: crate::api::Error) -> Self {
179        match api_error {
180            crate::api::Error::Http(e) => Self::Http(e),
181            crate::api::Error::Google(e) => Self::Other(e),
182        }
183    }
184}
185
186/// A representation of a GCP service account file. Contains the information required
187/// to obtain an access token via OAuth.
188pub struct ServiceAccount {
189    client_email: String,
190    private_key: EncodingKey,
191    token_uri: String,
192}
193
194impl ServiceAccount {
195    /// Reads the service account JSON file at `path` and attempts to parse it.
196    pub fn read_from_file(path: impl AsRef<std::path::Path>) -> Result<Self, ServiceAccountError> {
197        let path = path.as_ref();
198        std::fs::read_to_string(path)
199            .map_err(|error| ServiceAccountError::Io {
200                file: path.to_path_buf(),
201                error,
202            })?
203            .parse()
204    }
205
206    /// Reads the `ServiceAccount` from the file pointed to by the
207    /// `GOOGLE_APPLICATION_CREDENTIALS` environment variable.
208    pub fn read_from_canonical_env() -> Result<Self, ServiceAccountError> {
209        let service_account_path =
210            std::env::var_os("GOOGLE_APPLICATION_CREDENTIALS").unwrap_or_default();
211        Self::read_from_file(service_account_path)
212    }
213}
214
215impl FromStr for ServiceAccount {
216    type Err = ServiceAccountError;
217    fn from_str(sa_json: &str) -> Result<Self, Self::Err> {
218        let sa: DeserializableServiceAccount = serde_json::from_str(sa_json)?;
219        Ok(Self {
220            client_email: sa.client_email,
221            private_key: jsonwebtoken::EncodingKey::from_rsa_pem(sa.private_key.as_bytes())?,
222            token_uri: sa.token_uri,
223        })
224    }
225}
226impl<'de> Deserialize<'de> for ServiceAccount {
227    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
228        String::deserialize(d)?.parse().map_err(de::Error::custom)
229    }
230}
231
232#[derive(Deserialize)]
233struct DeserializableServiceAccount {
234    #[serde(rename = "type")]
235    _ty: ServiceAccountMarker,
236    client_email: String,
237    private_key: String,
238    token_uri: String,
239}
240
/// Value the `type` field of a service account JSON file is required to have.
const SERVICE_ACCOUNT_MARKER: &str = "service_account";
243
244struct ServiceAccountMarker;
245
246impl<'de> Deserialize<'de> for ServiceAccountMarker {
247    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
248        let ty: String = String::deserialize(d)?;
249        if ty == SERVICE_ACCOUNT_MARKER {
250            Ok(Self)
251        } else {
252            Err(de::Error::custom(&format!(
253                "provided JSON had unexpected `type` `{}`. expected `{}`.",
254                ty, SERVICE_ACCOUNT_MARKER
255            )))
256        }
257    }
258}
259
260/// An error occurring when loading the service account file.
261#[derive(Debug, thiserror::Error)]
262pub enum ServiceAccountError {
263    /// The service account file could not be read.
264    #[error("failed to read service account file `{file}`: {error}")]
265    Io {
266        /// The service account file, for informational purposes.
267        file: std::path::PathBuf,
268
269        /// The actual error that occured.
270        #[source]
271        error: std::io::Error,
272    },
273
274    /// The service account JSON could not be parsed into the expected format.
275    #[error("cound not parse service account json: {0}")]
276    Parse(#[from] serde_json::Error),
277
278    /// The private key was invalid.
279    #[error("invalid `private_key`: {0}")]
280    InvalidKey(#[from] jsonwebtoken::errors::Error),
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider from the service account named by `GOOGLE_APPLICATION_CREDENTIALS`,
    /// letting the caller tweak the account first.
    fn provider_with(tweak: impl FnOnce(&mut ServiceAccount)) -> OAuthTokenProvider {
        let mut service_account = ServiceAccount::read_from_canonical_env().unwrap();
        tweak(&mut service_account);
        OAuthTokenProvider::new(service_account, SCOPE_STORAGE_FULL_CONTROL).unwrap()
    }

    /// Requires real credentials and network access.
    #[tokio::test]
    async fn provides_token() {
        provider_with(|_| {}).get_token().await.unwrap();
    }

    /// A corrupted `client_email` must surface as a Google-reported error.
    #[tokio::test]
    async fn fails_sanely() {
        let provider = provider_with(|sa| sa.client_email.push('q'));
        let error = provider.get_token().await.unwrap_err();
        let oauth_error = error.downcast::<OAuthError>().unwrap();
        assert!(matches!(
            oauth_error,
            OAuthError::Other(crate::api::GoogleError { .. })
        ))
    }
}