1use crate::gcp::DeserializedResponse;
2use crate::Client;
3
4use super::{Error, TokenResult};
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Deserializer, Serialize};
7use std::{
8 fmt::{Debug, Display},
9 path::Path,
10};
11use urlencoding::encode;
12
/// An OAuth2 access token as returned by one of the Google token endpoints used
/// in this module.
///
/// Deserialized directly from the endpoint's JSON response; fields not declared
/// here (for example `id_token`) are ignored.
#[derive(Deserialize, Debug, Clone)]
pub struct Token {
    /// The bearer token string handed to API calls.
    access_token: String,
    // Not read outside the unit tests in this file, hence the lint suppression
    // for non-test builds.
    #[allow(dead_code)]
    token_type: String,
    /// Absolute expiry instant. The response carries a relative `expires_in`
    /// (seconds), which `from_expires_in` converts to a timestamp at
    /// deserialization time.
    #[serde(
        deserialize_with = "from_expires_in",
        rename(deserialize = "expires_in")
    )]
    expiry: DateTime<Utc>,
    /// Granted scope, taken from the optional `scope` response field or set
    /// afterwards through `Token::with_scope`.
    #[serde(default)]
    scope: Option<String>,
}
26
27const ONE_SECOND_TO_MICROSECONDS: i64 = 1_000_000;
28
29fn from_expires_in<'de, D>(deserializer: D) -> std::result::Result<DateTime<Utc>, D::Error>
30where
31 D: Deserializer<'de>,
32{
33 let expires_in: i64 = Deserialize::deserialize(deserializer)?;
34 Ok(Utc::now() + chrono::Duration::microseconds(expires_in * ONE_SECOND_TO_MICROSECONDS))
35}
36
37impl Display for Token {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 if self.is_valid() {
40 write!(f, "Valid Token expires at {}", self.expiry)
41 } else {
42 write!(f, "Invalid Token expired at {}", self.expiry)
43 }
44 }
45}
46
/// The raw bearer token string, as returned by `Token::access_token`.
pub type AccessToken = String;
48
49impl Token {
50 pub fn access_token(&self) -> AccessToken {
51 self.access_token.to_owned()
52 }
53
54 pub fn is_valid(&self) -> bool {
55 self.expiry - chrono::Duration::microseconds(30 * ONE_SECOND_TO_MICROSECONDS) > Utc::now()
56 }
57
58 pub fn with_scope(mut self, scope: String) -> Self {
59 self.scope = Some(scope);
60 self
61 }
62}
63
/// A source of OAuth2 access tokens.
///
/// The implementations in this module issue an HTTP request through `client` on
/// every call to `get` and do not cache the result. The `Sync + Send` bounds let
/// generators be shared across threads and used from `Send` futures.
#[async_trait::async_trait]
pub trait TokenGenerator: Sync + Send {
    /// Fetches a fresh token using the HTTP client held by `client`.
    async fn get(&self, client: &Client) -> TokenResult<Token>;
}
68
69impl Debug for dyn TokenGenerator {
70 fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 Ok(())
72 }
73}
74
75#[async_trait::async_trait]
76impl TokenGenerator for AuthorizedUserCredentials {
77 async fn get(&self, client: &Client) -> TokenResult<Token> {
78 let req = self;
79 let token: DeserializedResponse<Token> = client
80 .client
81 .post("https://accounts.google.com/o/oauth2/token")
82 .json(&req)
83 .send()
84 .await
85 .map_err(Error::HttpError)?
86 .json()
87 .await
88 .map_err(Error::HttpError)?;
89 token
90 .into_result()
91 .map_err(super::Error::unexpected_api_response::<Token>)
92 }
93}
94
95#[async_trait::async_trait]
96impl TokenGenerator for ServiceAccountCredentials {
97 async fn get(&self, client: &Client) -> TokenResult<Token> {
98 let now = chrono::Utc::now().timestamp();
99 let exp = now + 3600;
100
101 let scope = self.scope.to_owned().ok_or(super::Error::MissingScope)?;
102
103 let claims = Claims {
104 iss: self.client_email.as_str(),
105 scope: scope.as_str(),
106 aud: "https://www.googleapis.com/oauth2/v4/token",
107 exp,
108 iat: now,
109 };
110 let header = jsonwebtoken::Header {
111 alg: jsonwebtoken::Algorithm::RS256,
112 ..Default::default()
113 };
114 let private_key = jsonwebtoken::EncodingKey::from_rsa_pem(self.private_key.as_bytes())
115 .map_err(Error::JWTError)?;
116 let jwt = jsonwebtoken::encode(&header, &claims, &private_key).map_err(Error::JWTError)?;
117 let form = [
118 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
119 ("assertion", &jwt),
120 ];
121
122 let token: DeserializedResponse<Token> = client
123 .client
124 .post("https://www.googleapis.com/oauth2/v4/token")
125 .form(&form)
126 .send()
127 .await
128 .map_err(Error::HttpError)?
129 .json()
130 .await
131 .map_err(Error::HttpError)?;
132 token
133 .into_result()
134 .map(|t| t.with_scope(scope))
135 .map_err(super::Error::unexpected_api_response::<Token>)
136 }
137}
138
139#[async_trait::async_trait]
140impl TokenGenerator for GoogleMetadataServerCredentials {
141 async fn get(&self, client: &Client) -> TokenResult<Token> {
142 const DEFAULT_TOKEN_GCP_URI: &str = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token";
143 let uri = match self.scope {
144 None => DEFAULT_TOKEN_GCP_URI.to_owned(),
145 Some(ref scope) => format!(
146 "{}?{}",
147 DEFAULT_TOKEN_GCP_URI,
148 encode(format!("scopes={}", scope).as_str())
149 ),
150 };
151
152 let token: DeserializedResponse<Token> = client
153 .client
154 .get(uri)
155 .header("Metadata-Flavor", "Google")
156 .send()
157 .await
158 .map_err(Error::HttpError)?
159 .json()
160 .await
161 .map_err(Error::HttpError)?;
162 token
163 .into_result()
164 .map_err(super::Error::unexpected_api_response::<Token>)
165 }
166}
167
168fn from_str<T>(str: &str) -> TokenResult<T>
169where
170 T: serde::de::DeserializeOwned,
171{
172 serde_json::from_str(str).map_err(Error::deserialization_error::<T>)
173}
174
175async fn from_file<T, U>(file_path: T) -> TokenResult<U>
176where
177 T: AsRef<Path>,
178 U: serde::de::DeserializeOwned,
179{
180 tokio::fs::read_to_string(file_path.as_ref())
181 .await
182 .map_err(|err| Error::io_error("error while reading file", file_path.as_ref(), err))
183 .and_then(|f| from_str(f.as_str()))
184}
185
186async fn default<T>() -> TokenResult<T>
187where
188 T: serde::de::DeserializeOwned,
189{
190 let default_path = {
191 let key = "GOOGLE_APPLICATION_CREDENTIALS";
192 std::env::var(key).map_err(|err| Error::env_var_error(key, err))?
193 };
194 from_file(default_path).await
195}
196
/// JWT claim set that the service-account token generator signs and sends as
/// the `assertion` of a JWT-bearer grant.
///
/// Fields are serialized in declaration order.
#[derive(Serialize, Debug)]
struct Claims<'a> {
    /// Issuer: the service account's client email.
    iss: &'a str,
    /// Audience: the token endpoint the assertion is posted to.
    aud: &'a str,
    /// Expiration time, in seconds since the Unix epoch.
    exp: i64,
    /// Issued-at time, in seconds since the Unix epoch.
    iat: i64,
    /// Scope being requested, as configured on the credentials.
    scope: &'a str,
}
205
/// End-user OAuth2 credentials, e.g. parsed from an `authorized_user` JSON key
/// file. Keys not declared here (such as `type` or `quota_project_id`) are
/// ignored.
///
/// The whole struct is serialized as the JSON body of the refresh request in its
/// `TokenGenerator` impl, which is why it carries `grant_type`.
///
/// NOTE(review): the derived `Debug` prints `client_secret` and `refresh_token`
/// verbatim, so avoid logging values of this type.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct AuthorizedUserCredentials {
    client_id: String,
    client_secret: String,
    refresh_token: String,
    /// OAuth2 grant type sent with the refresh request. Defaults to
    /// `"refresh_token"` when the input JSON omits it.
    #[serde(default = "refresh_token")]
    grant_type: String,
}
214
/// Serde default for `AuthorizedUserCredentials::grant_type`: the literal
/// `"refresh_token"` grant type.
fn refresh_token() -> String {
    String::from("refresh_token")
}
218
219impl AuthorizedUserCredentials {
220 pub fn from(s: &str) -> TokenResult<Self> {
221 from_str(s)
222 }
223
224 pub async fn from_file<T>(file_path: T) -> TokenResult<Self>
225 where
226 T: AsRef<Path>,
227 {
228 from_file(file_path).await
229 }
230
231 pub async fn default() -> TokenResult<Self> {
232 default().await
233 }
234}
235
/// Service-account credentials, with field names matching a `service_account`
/// JSON key file.
///
/// Only `client_email`, `private_key` and `scope` are used when generating
/// tokens in this module. The remaining fields are parsed and kept as-is.
///
/// NOTE(review): the derived `Debug` prints `private_key` verbatim, so avoid
/// logging values of this type.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct ServiceAccountCredentials {
    r#type: String,
    project_id: String,
    private_key_id: String,
    /// PEM-encoded RSA private key used to sign the JWT assertion.
    private_key: String,
    /// Service account email, used as the JWT issuer (`iss`).
    client_email: String,
    client_id: String,
    auth_uri: String,
    token_uri: String,
    auth_provider_x509_cert_url: String,
    client_x509_cert_url: String,
    /// Scope to request. Token generation fails with `MissingScope` while this
    /// is `None`; set it with `ServiceAccountCredentials::with_scope`.
    #[serde(default)]
    scope: Option<String>,
}
251
252impl ServiceAccountCredentials {
253 pub fn from(s: &str) -> TokenResult<Self> {
254 from_str(s)
255 }
256
257 pub async fn from_file<T>(file_path: T) -> TokenResult<Self>
258 where
259 T: AsRef<Path>,
260 {
261 from_file(file_path).await
262 }
263
264 pub async fn default() -> TokenResult<Self> {
265 default().await
266 }
267
268 pub fn with_scope(mut self, scope: &str) -> Self {
269 self.scope = Some(scope.to_owned());
270 self
271 }
272}
273
/// Credentials that fetch tokens from the GCE metadata server for the instance's
/// default service account.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct GoogleMetadataServerCredentials {
    /// Optional scope forwarded to the metadata server. `None` requests the
    /// account's default scopes.
    scope: Option<String>,
}
278
279impl GoogleMetadataServerCredentials {
280 pub fn new() -> TokenResult<Self> {
281 Ok(GoogleMetadataServerCredentials { scope: None })
282 }
283 pub fn with_scope(mut self, scope: &str) -> Self {
284 self.scope = Some(scope.to_owned());
285 self
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway token that expires at `expiry`.
    fn token_expiring_at(expiry: DateTime<Utc>) -> Token {
        Token {
            access_token: "Hello".to_owned(),
            token_type: "token type".to_owned(),
            expiry,
            scope: None,
        }
    }

    /// A token-endpoint response deserializes into a `Token`. `expires_in`
    /// becomes a future expiry, and unknown fields such as `id_token` are
    /// ignored.
    #[test]
    fn token_from_json_test() {
        let raw = r#"{
            "access_token": "access_token",
            "expires_in": 3599,
            "scope": "scope",
            "token_type": "Bearer",
            "id_token": "id_token"
        }"#;

        let token: Token = serde_json::from_str(raw).unwrap();
        assert_eq!(token.access_token, "access_token");
        assert_eq!(token.token_type, "Bearer");
        assert!(Utc::now() < token.expiry);
    }

    /// An `authorized_user` key file parses. `grant_type` takes its default and
    /// unknown keys are ignored.
    #[test]
    fn token_from_authorized_user_json_test() {
        let expected = AuthorizedUserCredentials {
            client_id: "client_id".to_owned(),
            client_secret: "client_secret".to_owned(),
            refresh_token: "refresh_token".to_owned(),
            grant_type: "refresh_token".to_owned(),
        };

        let parsed: AuthorizedUserCredentials = super::from_str(
            r#"{
                "client_id": "client_id",
                "client_secret": "client_secret",
                "quota_project_id": "quota_project_id",
                "refresh_token": "refresh_token",
                "type": "authorized_user"
            }"#,
        )
        .unwrap();

        assert_eq!(expected, parsed);
    }

    /// A `service_account` key file parses, and `scope` stays unset.
    #[test]
    fn token_from_service_account_json_test() {
        let expected = ServiceAccountCredentials {
            r#type: "service_account".to_owned(),
            project_id: "project_id".to_owned(),
            private_key_id: "private_key_id".to_owned(),
            private_key: "private_key".to_owned(),
            client_email: "client_email".to_owned(),
            client_id: "client_id".to_owned(),
            auth_uri: "auth_uri".to_owned(),
            token_uri: "token_uri".to_owned(),
            auth_provider_x509_cert_url: "auth_provider_x509_cert_url".to_owned(),
            client_x509_cert_url: "client_x509_cert_url".to_owned(),
            scope: None,
        };

        let parsed: ServiceAccountCredentials = super::from_str(
            r#"{
                "type": "service_account",
                "project_id": "project_id",
                "private_key_id": "private_key_id",
                "private_key": "private_key",
                "client_email": "client_email",
                "client_id": "client_id",
                "auth_uri": "auth_uri",
                "token_uri": "token_uri",
                "auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
                "client_x509_cert_url": "client_x509_cert_url"
            }"#,
        )
        .unwrap();

        assert_eq!(expected, parsed);
    }

    /// A token expiring right now is inside the safety margin, so it is invalid.
    #[test]
    fn test_token_is_valid_false() {
        let token = token_expiring_at(chrono::Utc::now());

        assert!(!token.is_valid());
        let rendered = token.to_string();
        assert!(
            rendered.starts_with("Invalid Token expired at"),
            "expected an invalid token but got {}",
            rendered
        );
    }

    /// A token expiring in 35 seconds is outside the 30 second margin, so it is
    /// valid.
    #[test]
    fn test_token_is_valid_true() {
        let token = token_expiring_at(
            chrono::Utc::now() + chrono::Duration::microseconds(35 * ONE_SECOND_TO_MICROSECONDS),
        );

        assert!(token.is_valid());
        let rendered = token.to_string();
        assert!(
            rendered.starts_with("Valid Token expires at"),
            "expected a valid token but got {}",
            rendered
        );
    }
}