gcs_rsync/gcp/oauth2/
token.rs

1use crate::gcp::DeserializedResponse;
2use crate::Client;
3
4use super::{Error, TokenResult};
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Deserializer, Serialize};
7use std::{
8    fmt::{Debug, Display},
9    path::Path,
10};
11use urlencoding::encode;
12
13#[derive(Deserialize, Debug, Clone)]
14pub struct Token {
15    access_token: String,
16    #[allow(dead_code)]
17    token_type: String,
18    #[serde(
19        deserialize_with = "from_expires_in",
20        rename(deserialize = "expires_in")
21    )]
22    expiry: DateTime<Utc>,
23    #[serde(default)]
24    scope: Option<String>,
25}
26
/// Number of microseconds in one second, used to build `chrono::Duration`s.
const ONE_SECOND_TO_MICROSECONDS: i64 = 1_000 * 1_000;
28
29fn from_expires_in<'de, D>(deserializer: D) -> std::result::Result<DateTime<Utc>, D::Error>
30where
31    D: Deserializer<'de>,
32{
33    let expires_in: i64 = Deserialize::deserialize(deserializer)?;
34    Ok(Utc::now() + chrono::Duration::microseconds(expires_in * ONE_SECOND_TO_MICROSECONDS))
35}
36
37impl Display for Token {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        if self.is_valid() {
40            write!(f, "Valid Token expires at {}", self.expiry)
41        } else {
42            write!(f, "Invalid Token expired at {}", self.expiry)
43        }
44    }
45}
46
47pub type AccessToken = String;
48
49impl Token {
50    pub fn access_token(&self) -> AccessToken {
51        self.access_token.to_owned()
52    }
53
54    pub fn is_valid(&self) -> bool {
55        self.expiry - chrono::Duration::microseconds(30 * ONE_SECOND_TO_MICROSECONDS) > Utc::now()
56    }
57
58    pub fn with_scope(mut self, scope: String) -> Self {
59        self.scope = Some(scope);
60        self
61    }
62}
63
/// A source of OAuth2 [`Token`]s.
///
/// The implementations in this file request a token over HTTP through
/// `client.client` and decode the response into a [`Token`].
#[async_trait::async_trait]
pub trait TokenGenerator: Sync + Send {
    /// Requests a new token using the given client.
    async fn get(&self, client: &Client) -> TokenResult<Token>;
}
68
69impl Debug for dyn TokenGenerator {
70    fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        Ok(())
72    }
73}
74
75#[async_trait::async_trait]
76impl TokenGenerator for AuthorizedUserCredentials {
77    async fn get(&self, client: &Client) -> TokenResult<Token> {
78        let req = self;
79        let token: DeserializedResponse<Token> = client
80            .client
81            .post("https://accounts.google.com/o/oauth2/token")
82            .json(&req)
83            .send()
84            .await
85            .map_err(Error::HttpError)?
86            .json()
87            .await
88            .map_err(Error::HttpError)?;
89        token
90            .into_result()
91            .map_err(super::Error::unexpected_api_response::<Token>)
92    }
93}
94
95#[async_trait::async_trait]
96impl TokenGenerator for ServiceAccountCredentials {
97    async fn get(&self, client: &Client) -> TokenResult<Token> {
98        let now = chrono::Utc::now().timestamp();
99        let exp = now + 3600;
100
101        let scope = self.scope.to_owned().ok_or(super::Error::MissingScope)?;
102
103        let claims = Claims {
104            iss: self.client_email.as_str(),
105            scope: scope.as_str(),
106            aud: "https://www.googleapis.com/oauth2/v4/token",
107            exp,
108            iat: now,
109        };
110        let header = jsonwebtoken::Header {
111            alg: jsonwebtoken::Algorithm::RS256,
112            ..Default::default()
113        };
114        let private_key = jsonwebtoken::EncodingKey::from_rsa_pem(self.private_key.as_bytes())
115            .map_err(Error::JWTError)?;
116        let jwt = jsonwebtoken::encode(&header, &claims, &private_key).map_err(Error::JWTError)?;
117        let form = [
118            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
119            ("assertion", &jwt),
120        ];
121
122        let token: DeserializedResponse<Token> = client
123            .client
124            .post("https://www.googleapis.com/oauth2/v4/token")
125            .form(&form)
126            .send()
127            .await
128            .map_err(Error::HttpError)?
129            .json()
130            .await
131            .map_err(Error::HttpError)?;
132        token
133            .into_result()
134            .map(|t| t.with_scope(scope))
135            .map_err(super::Error::unexpected_api_response::<Token>)
136    }
137}
138
139#[async_trait::async_trait]
140impl TokenGenerator for GoogleMetadataServerCredentials {
141    async fn get(&self, client: &Client) -> TokenResult<Token> {
142        const DEFAULT_TOKEN_GCP_URI: &str = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token";
143        let uri = match self.scope {
144            None => DEFAULT_TOKEN_GCP_URI.to_owned(),
145            Some(ref scope) => format!(
146                "{}?{}",
147                DEFAULT_TOKEN_GCP_URI,
148                encode(format!("scopes={}", scope).as_str())
149            ),
150        };
151
152        let token: DeserializedResponse<Token> = client
153            .client
154            .get(uri)
155            .header("Metadata-Flavor", "Google")
156            .send()
157            .await
158            .map_err(Error::HttpError)?
159            .json()
160            .await
161            .map_err(Error::HttpError)?;
162        token
163            .into_result()
164            .map_err(super::Error::unexpected_api_response::<Token>)
165    }
166}
167
168fn from_str<T>(str: &str) -> TokenResult<T>
169where
170    T: serde::de::DeserializeOwned,
171{
172    serde_json::from_str(str).map_err(Error::deserialization_error::<T>)
173}
174
175async fn from_file<T, U>(file_path: T) -> TokenResult<U>
176where
177    T: AsRef<Path>,
178    U: serde::de::DeserializeOwned,
179{
180    tokio::fs::read_to_string(file_path.as_ref())
181        .await
182        .map_err(|err| Error::io_error("error while reading file", file_path.as_ref(), err))
183        .and_then(|f| from_str(f.as_str()))
184}
185
186async fn default<T>() -> TokenResult<T>
187where
188    T: serde::de::DeserializeOwned,
189{
190    let default_path = {
191        let key = "GOOGLE_APPLICATION_CREDENTIALS";
192        std::env::var(key).map_err(|err| Error::env_var_error(key, err))?
193    };
194    from_file(default_path).await
195}
196
/// JWT claim set signed for the service-account JWT-bearer grant.
///
/// Claims are serialized in field declaration order.
#[derive(Serialize, Debug)]
struct Claims<'a> {
    /// Issuer: the service account's `client_email`.
    iss: &'a str,
    /// Audience: the token endpoint the signed assertion is posted to.
    aud: &'a str,
    /// Expiry time, in seconds since the Unix epoch.
    exp: i64,
    /// Issued-at time, in seconds since the Unix epoch.
    iat: i64,
    /// OAuth2 scope being requested.
    scope: &'a str,
}
205
/// `authorized_user` credentials: an OAuth2 client plus a refresh token.
///
/// The struct is serialized as-is as the JSON body of the token request, so
/// its field names double as the request parameters. Unknown keys in the
/// credential file (e.g. `type`, `quota_project_id`) are ignored.
///
/// NOTE(review): the derived `Debug` prints `client_secret` and
/// `refresh_token` in clear text — avoid logging this value with `{:?}`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct AuthorizedUserCredentials {
    client_id: String,
    client_secret: String,
    refresh_token: String,
    /// OAuth2 grant type; defaults to `"refresh_token"` (via `refresh_token()`)
    /// when missing from the input JSON.
    #[serde(default = "refresh_token")]
    grant_type: String,
}
214
/// Serde default for `AuthorizedUserCredentials::grant_type`.
fn refresh_token() -> String {
    String::from("refresh_token")
}
218
219impl AuthorizedUserCredentials {
220    pub fn from(s: &str) -> TokenResult<Self> {
221        from_str(s)
222    }
223
224    pub async fn from_file<T>(file_path: T) -> TokenResult<Self>
225    where
226        T: AsRef<Path>,
227    {
228        from_file(file_path).await
229    }
230
231    pub async fn default() -> TokenResult<Self> {
232        default().await
233    }
234}
235
/// Service-account credentials as stored in a service account JSON key file.
///
/// NOTE(review): the derived `Debug` (and `Serialize`) include `private_key`
/// in clear text — avoid logging this value with `{:?}`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct ServiceAccountCredentials {
    r#type: String,
    project_id: String,
    private_key_id: String,
    /// PEM-encoded RSA private key used to sign the JWT assertion.
    private_key: String,
    /// Service account email; used as the JWT issuer (`iss`).
    client_email: String,
    client_id: String,
    auth_uri: String,
    token_uri: String,
    auth_provider_x509_cert_url: String,
    client_x509_cert_url: String,
    /// Scope to request when generating tokens. It defaults to `None` when
    /// absent from the JSON and is set via `with_scope`. Token generation
    /// fails with `MissingScope` while it is `None`.
    #[serde(default)]
    scope: Option<String>,
}
251
252impl ServiceAccountCredentials {
253    pub fn from(s: &str) -> TokenResult<Self> {
254        from_str(s)
255    }
256
257    pub async fn from_file<T>(file_path: T) -> TokenResult<Self>
258    where
259        T: AsRef<Path>,
260    {
261        from_file(file_path).await
262    }
263
264    pub async fn default() -> TokenResult<Self> {
265        default().await
266    }
267
268    pub fn with_scope(mut self, scope: &str) -> Self {
269        self.scope = Some(scope.to_owned());
270        self
271    }
272}
273
/// Credentials obtained from the metadata server; no key material is stored.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct GoogleMetadataServerCredentials {
    /// Scope to request. When `None`, the token request carries no `scopes`
    /// query parameter.
    scope: Option<String>,
}
278
279impl GoogleMetadataServerCredentials {
280    pub fn new() -> TokenResult<Self> {
281        Ok(GoogleMetadataServerCredentials { scope: None })
282    }
283    pub fn with_scope(mut self, scope: &str) -> Self {
284        self.scope = Some(scope.to_owned());
285        self
286    }
287}
288
#[cfg(test)]
mod tests {
    use crate::gcp::oauth2::token::*;

    /// Builds a throwaway token that expires at `expiry`.
    fn token_with_expiry(expiry: chrono::DateTime<chrono::Utc>) -> Token {
        Token {
            access_token: "Hello".to_owned(),
            token_type: "token type".to_owned(),
            expiry,
            scope: None,
        }
    }

    #[test]
    fn token_from_json_test() {
        let json = r#"{
            "access_token": "access_token",
            "expires_in": 3599,
            "scope": "scope",
            "token_type": "Bearer",
            "id_token": "id_token"
        }"#;

        let token: Token = serde_json::from_str(json).unwrap();
        assert_eq!(token.access_token, "access_token");
        assert_eq!(token.token_type, "Bearer");
        assert!(token.expiry > chrono::Utc::now());
    }

    #[test]
    fn token_from_authorized_user_json_test() {
        let parsed: AuthorizedUserCredentials = super::from_str(
            r#"{
                   "client_id": "client_id",
                   "client_secret": "client_secret",
                   "quota_project_id": "quota_project_id",
                   "refresh_token": "refresh_token",
                   "type": "authorized_user"
            }"#,
        )
        .unwrap();
        let expected = AuthorizedUserCredentials {
            client_id: "client_id".to_owned(),
            client_secret: "client_secret".to_owned(),
            refresh_token: "refresh_token".to_owned(),
            grant_type: "refresh_token".to_owned(),
        };

        assert_eq!(expected, parsed);
    }

    #[test]
    fn token_from_service_account_json_test() {
        let parsed: ServiceAccountCredentials = super::from_str(
            r#"{
                "type": "service_account",
                "project_id": "project_id",
                "private_key_id": "private_key_id",
                "private_key": "private_key",
                "client_email": "client_email",
                "client_id": "client_id",
                "auth_uri": "auth_uri",
                "token_uri": "token_uri",
                "auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
                "client_x509_cert_url": "client_x509_cert_url"
            }"#,
        )
        .unwrap();
        let expected = ServiceAccountCredentials {
            r#type: "service_account".to_owned(),
            project_id: "project_id".to_owned(),
            private_key_id: "private_key_id".to_owned(),
            private_key: "private_key".to_owned(),
            client_email: "client_email".to_owned(),
            client_id: "client_id".to_owned(),
            auth_uri: "auth_uri".to_owned(),
            token_uri: "token_uri".to_owned(),
            auth_provider_x509_cert_url: "auth_provider_x509_cert_url".to_owned(),
            client_x509_cert_url: "client_x509_cert_url".to_owned(),
            scope: None,
        };

        assert_eq!(expected, parsed);
    }

    #[test]
    fn test_token_is_valid_false() {
        let token = token_with_expiry(chrono::Utc::now());

        assert!(!token.is_valid());
        let rendered = format!("{}", token);
        assert!(
            rendered.starts_with("Invalid Token expired at"),
            "expected an invalid token but got {}",
            token
        )
    }

    #[test]
    fn test_token_is_valid_true() {
        let token = token_with_expiry(
            chrono::Utc::now() + chrono::Duration::microseconds(35 * ONE_SECOND_TO_MICROSECONDS),
        );

        assert!(token.is_valid());
        let rendered = format!("{}", token);
        assert!(
            rendered.starts_with("Valid Token expires at"),
            "expected a valid token but got {}",
            token
        )
    }
}