cloud_storage_lite/token_provider/
oauth.rs1use std::{str::FromStr, sync::Arc};
5
6use anyhow::Context;
7use chrono::{DateTime, Duration, Utc};
8use jsonwebtoken::EncodingKey;
9use serde::{
10 de::{self, Deserializer},
11 Deserialize, Serialize,
12};
13use serde_with::serde_as;
14
15use super::{Token, TokenProvider};
16
/// OAuth scope URL for `devstorage.full_control` (full control of Cloud Storage).
///
/// Pass this as the `scope` argument of [`OAuthTokenProvider::new`].
pub const SCOPE_STORAGE_FULL_CONTROL: &str =
    "https://www.googleapis.com/auth/devstorage.full_control";
20
/// A [`TokenProvider`] that exchanges a signed service-account JWT for an
/// OAuth access token.
pub struct OAuthTokenProvider {
    /// OAuth scope placed in the JWT assertion's `scope` claim.
    scope: String,

    /// Credentials used to sign the assertion and locate the token endpoint.
    service_account: ServiceAccount,

    /// HTTP client used to POST the assertion to the token endpoint.
    client: reqwest::Client,
}
30
31impl OAuthTokenProvider {
32 pub fn new(
35 service_account: ServiceAccount,
36 scope: impl Into<String>,
37 ) -> Result<Self, OAuthError> {
38 Self::new_with_client(service_account, scope, Default::default())
39 }
40
41 pub fn new_with_client(
44 service_account: ServiceAccount,
45 scope: impl Into<String>,
46 client: reqwest::Client,
47 ) -> Result<Self, OAuthError> {
48 Ok(Self {
49 scope: scope.into(),
50 service_account,
51 client,
52 })
53 }
54}
55
56#[async_trait::async_trait]
57impl TokenProvider for OAuthTokenProvider {
58 async fn get_token(&self) -> anyhow::Result<Arc<Token>> {
59 let header = jsonwebtoken::Header {
60 alg: jsonwebtoken::Algorithm::RS256,
61 ..Default::default()
62 };
63
64 let now = Utc::now();
65 let expiry = now + Duration::hours(1);
66
67 let claims = Claims {
68 iss: &self.service_account.client_email,
69 scope: &self.scope,
70 aud: &self.service_account.token_uri,
71 iat: now,
72 exp: expiry,
73 };
74
75 let client_assertion =
76 jsonwebtoken::encode(&header, &claims, &self.service_account.private_key)?;
77
78 let res = self
79 .client
80 .post(&self.service_account.token_uri)
81 .form(&[
82 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
83 ("assertion", &client_assertion),
84 ])
85 .send()
86 .await
87 .context("failed to request access token from Google")?;
88 let res_status = res.status();
89 let (token, expires_in) = match res.json::<OAuthResponse>().await? {
90 OAuthResponse::Token {
91 token: TokenKind::IdToken(..),
92 ..
93 } => return Err(OAuthError::InvalidScope.into()),
94 OAuthResponse::Token {
95 token: TokenKind::AccessToken(token),
96 expires_in,
97 } => (token, expires_in),
98 OAuthResponse::Error {
99 error_description, ..
100 } => {
101 return Err(OAuthError::Other(crate::api::GoogleError {
102 status: res_status,
103 message: error_description,
104 })
105 .into())
106 }
107 };
108
109 Ok(Arc::new(Token {
110 token,
111 expiry: now + expires_in,
112 }))
113 }
114}
115
/// JWT claim set for the service-account assertion.
///
/// Field declaration order is the key order of the serialized JSON payload.
#[serde_as]
#[derive(Serialize)]
struct Claims<'a> {
    /// Issuer; filled with the service account's `client_email`.
    iss: &'a str,

    /// Audience; filled with the service account's `token_uri`.
    aud: &'a str,

    /// OAuth scope being requested.
    scope: &'a str,

    /// Expiry time, serialized as whole Unix seconds.
    #[serde_as(as = "serde_with::TimestampSeconds")]
    exp: DateTime<Utc>,

    /// Issued-at time, serialized as whole Unix seconds.
    #[serde_as(as = "serde_with::TimestampSeconds")]
    iat: DateTime<Utc>,
}
132
/// JSON body returned by the token endpoint.
///
/// `untagged`: serde tries `Token` first and falls back to `Error`, so variant
/// order here is significant.
#[serde_as]
#[derive(Deserialize)]
#[serde(untagged)]
enum OAuthResponse {
    /// Success body carrying either an access token or an ID token.
    Token {
        /// Chosen by whichever of the `id_token` / `access_token` keys is present.
        #[serde(flatten)]
        token: TokenKind,

        /// Token lifetime, given as an integer number of seconds.
        #[serde_as(as = "serde_with::DurationSeconds<i64>")]
        expires_in: Duration,
    },
    /// Error body; only `error_description` is kept (other keys are ignored).
    Error {
        error_description: String,
    },
}
148
/// Which kind of token the endpoint returned, keyed by the JSON field name
/// (`id_token` or `access_token`, via `rename_all = "snake_case"`).
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum TokenKind {
    /// `id_token`: rejected by the provider as a scope misconfiguration.
    IdToken(String),
    /// `access_token`: the bearer token the provider hands out.
    AccessToken(String),
}
155
/// Errors produced while obtaining an OAuth access token.
#[derive(Debug, thiserror::Error)]
pub enum OAuthError {
    /// Wraps a `jsonwebtoken` error; displayed as an invalid RSA private key.
    #[error("invalid RSA private key: {0}")]
    InvalidSigningKey(#[from] jsonwebtoken::errors::Error),

    /// An error from the `reqwest` HTTP client (displayed transparently).
    #[error(transparent)]
    Http(#[from] reqwest::Error),

    /// An error reported by Google, carrying the HTTP status and message.
    #[error(transparent)]
    Other(#[from] crate::api::GoogleError),

    /// An ID token was returned where an access token was expected.
    #[error("received an ID token instead of an access token. ensure that the scope is correct.")]
    InvalidScope,
}
176
177impl From<crate::api::Error> for OAuthError {
178 fn from(api_error: crate::api::Error) -> Self {
179 match api_error {
180 crate::api::Error::Http(e) => Self::Http(e),
181 crate::api::Error::Google(e) => Self::Other(e),
182 }
183 }
184}
185
/// Credentials parsed from a Google service-account JSON key.
pub struct ServiceAccount {
    /// Service-account email; used as the JWT issuer (`iss`).
    client_email: String,
    /// RSA signing key, parsed from the PEM `private_key` field.
    private_key: EncodingKey,
    /// Token endpoint URL; used as the JWT audience and as the POST target.
    token_uri: String,
}
193
194impl ServiceAccount {
195 pub fn read_from_file(path: impl AsRef<std::path::Path>) -> Result<Self, ServiceAccountError> {
197 let path = path.as_ref();
198 std::fs::read_to_string(path)
199 .map_err(|error| ServiceAccountError::Io {
200 file: path.to_path_buf(),
201 error,
202 })?
203 .parse()
204 }
205
206 pub fn read_from_canonical_env() -> Result<Self, ServiceAccountError> {
209 let service_account_path =
210 std::env::var_os("GOOGLE_APPLICATION_CREDENTIALS").unwrap_or_default();
211 Self::read_from_file(service_account_path)
212 }
213}
214
215impl FromStr for ServiceAccount {
216 type Err = ServiceAccountError;
217 fn from_str(sa_json: &str) -> Result<Self, Self::Err> {
218 let sa: DeserializableServiceAccount = serde_json::from_str(sa_json)?;
219 Ok(Self {
220 client_email: sa.client_email,
221 private_key: jsonwebtoken::EncodingKey::from_rsa_pem(sa.private_key.as_bytes())?,
222 token_uri: sa.token_uri,
223 })
224 }
225}
226impl<'de> Deserialize<'de> for ServiceAccount {
227 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
228 String::deserialize(d)?.parse().map_err(de::Error::custom)
229 }
230}
231
/// Raw shape of a service-account JSON key; keys not listed here are ignored.
#[derive(Deserialize)]
struct DeserializableServiceAccount {
    /// The JSON `type` field; must be `"service_account"` or deserialization fails.
    #[serde(rename = "type")]
    _ty: ServiceAccountMarker,
    client_email: String,
    /// PEM text of the RSA private key (loaded later with `from_rsa_pem`).
    private_key: String,
    token_uri: String,
}
240
241const SERVICE_ACCOUNT_MARKER: &str = "service_account";
243
244struct ServiceAccountMarker;
245
246impl<'de> Deserialize<'de> for ServiceAccountMarker {
247 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
248 let ty: String = String::deserialize(d)?;
249 if ty == SERVICE_ACCOUNT_MARKER {
250 Ok(Self)
251 } else {
252 Err(de::Error::custom(&format!(
253 "provided JSON had unexpected `type` `{}`. expected `{}`.",
254 ty, SERVICE_ACCOUNT_MARKER
255 )))
256 }
257 }
258}
259
260#[derive(Debug, thiserror::Error)]
262pub enum ServiceAccountError {
263 #[error("failed to read service account file `{file}`: {error}")]
265 Io {
266 file: std::path::PathBuf,
268
269 #[source]
271 error: std::io::Error,
272 },
273
274 #[error("cound not parse service account json: {0}")]
276 Parse(#[from] serde_json::Error),
277
278 #[error("invalid `private_key`: {0}")]
280 InvalidKey(#[from] jsonwebtoken::errors::Error),
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: these tests read credentials from `GOOGLE_APPLICATION_CREDENTIALS`
    // and call the real token endpoint, so they need both to be available.

    /// Builds a provider for `sa` requesting full Cloud Storage control.
    fn full_control_provider(sa: ServiceAccount) -> OAuthTokenProvider {
        OAuthTokenProvider::new(sa, SCOPE_STORAGE_FULL_CONTROL).unwrap()
    }

    #[tokio::test]
    async fn provides_token() {
        let sa = ServiceAccount::read_from_canonical_env().unwrap();
        full_control_provider(sa).get_token().await.unwrap();
    }

    #[tokio::test]
    async fn fails_sanely() {
        let mut sa = ServiceAccount::read_from_canonical_env().unwrap();
        // Corrupt the issuer so the token endpoint rejects the assertion.
        sa.client_email.push('q');
        let failure = full_control_provider(sa).get_token().await.unwrap_err();
        let oauth_err = failure.downcast::<OAuthError>().unwrap();
        assert!(matches!(
            oauth_err,
            OAuthError::Other(crate::api::GoogleError { .. })
        ))
    }
}