google_cloud_auth/credentials/
user_account.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [User Account] Credentials type.
16//!
17//! User accounts represent a developer, administrator, or any other person who
18//! interacts with Google APIs and services. User accounts are managed as
19//! [Google Accounts], either via [Google Workspace] or [Cloud Identity].
20//!
21//! This module provides [Credentials] derived from user account
22//! information, specifically utilizing an OAuth 2.0 refresh token.
23//!
24//! This module is designed for refresh tokens obtained via the standard
25//! [Authorization Code grant]. Acquiring the initial refresh token (e.g., through
26//! user consent) is outside the scope of this library.
27//! See [RFC 6749 Section 4.1] for flow details.
28//!
29//! The Google Cloud client libraries for Rust will typically find and use these
30//! credentials automatically if a credentials file exists in the
31//! standard ADC search paths. This file is often created by running:
32//! `gcloud auth application-default login`. You might instantiate these credentials
33//! directly using the [`Builder`] if you need to:
34//! * Load credentials from a non-standard location or source.
35//! * Override the OAuth 2.0 **scopes** being requested for the access token.
36//! * Override the **quota project ID** for billing and quota management.
37//! * Override the **token URI** used to fetch access tokens.
38//!
39//! Example usage:
40//!
41//! ```
42//! # use google_cloud_auth::credentials::user_account::Builder;
43//! # use google_cloud_auth::credentials::Credentials;
44//! # use google_cloud_auth::errors::CredentialsError;
45//! # tokio_test::block_on(async {
46//! let authorized_user = serde_json::json!({
47//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com", // Replace with your actual Client ID
48//!     "client_secret": "YOUR_CLIENT_SECRET", // Replace with your actual Client Secret - LOAD SECURELY!
49//!     "refresh_token": "YOUR_REFRESH_TOKEN", // Replace with the user's refresh token - LOAD SECURELY!
50//!     "type": "authorized_user",
51//!     // "quota_project_id": "your-billing-project-id", // Optional: Set if needed
52//!     // "token_uri" : "test-token-uri", // Optional: Set if needed
53//! });
54//! let credentials: Credentials = Builder::new(authorized_user).build()?;
55//! let token = credentials.token().await?;
56//! println!("Token: {}", token.token);
57//! # Ok::<(), CredentialsError>(())
58//! # });
59//! ```
60//!
61//! [Authorization Code grant]: https://tools.ietf.org/html/rfc6749#section-1.3.1
62//! [Cloud Identity]: https://cloud.google.com/identity
63//! [Google Accounts]: https://myaccount.google.com/
64//! [Google Workspace]: https://workspace.google.com/
65//! [RFC 6749 Section 4.1]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.1
66//! [User Account]: https://cloud.google.com/docs/authentication#user-accounts
67//! [Workforce Identity Federation]: https://cloud.google.com/iam/docs/workforce-identity-federation
68
69use crate::credentials::dynamic::CredentialsTrait;
70use crate::credentials::{Credentials, Result};
71use crate::errors::{self, CredentialsError, is_retryable};
72use crate::headers_util::build_bearer_headers;
73use crate::token::{Token, TokenProvider};
74use crate::token_cache::TokenCache;
75use http::header::{CONTENT_TYPE, HeaderName, HeaderValue};
76use reqwest::{Client, Method};
77use serde_json::Value;
78use std::sync::Arc;
79use std::time::Duration;
80
/// Default OAuth 2.0 token endpoint, used when neither the builder nor the
/// `authorized_user` JSON supplies a `token_uri`.
const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
82
83pub(crate) fn creds_from(js: Value) -> Result<Credentials> {
84    Builder::new(js).build()
85}
86
/// A builder for constructing a `user_account` [Credentials] instance.
///
/// # Example
/// ```
/// # use google_cloud_auth::credentials::user_account::Builder;
/// # tokio_test::block_on(async {
/// let authorized_user = serde_json::json!({ /* add details here */ });
/// let credentials = Builder::new(authorized_user).build();
/// # });
/// ```
pub struct Builder {
    /// Raw `authorized_user` JSON; it is only deserialized by `build()`.
    authorized_user: Value,
    /// Scopes requested for the access token, if overridden by the caller.
    scopes: Option<Vec<String>>,
    /// Quota project override; takes precedence over the value in the JSON.
    quota_project_id: Option<String>,
    /// Token endpoint override; takes precedence over the value in the JSON.
    token_uri: Option<String>,
}
103
104impl Builder {
105    /// Creates a new builder using `authorized_user` JSON value.
106    ///
107    /// The `authorized_user` JSON is typically generated when a user
108    /// authenticates using the [application-default login] process.
109    ///
110    /// [application-default login]: https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login
111    pub fn new(authorized_user: Value) -> Self {
112        Self {
113            authorized_user,
114            scopes: None,
115            quota_project_id: None,
116            token_uri: None,
117        }
118    }
119
120    /// Sets the URI for the token endpoint used to fetch access tokens.
121    ///
122    /// Any value provided here overrides a `token_uri` value from the input `authorized_user` JSON.
123    /// Defaults to `https://oauth2.googleapis.com/token` if not specified here or in the `authorized_user` JSON.
124    ///
125    /// # Example
126    /// ```
127    /// # use google_cloud_auth::credentials::user_account::Builder;
128    /// let authorized_user = serde_json::json!({ /* add details here */ });
129    /// let credentials = Builder::new(authorized_user)
130    ///     .with_token_uri("https://oauth2-FOOBAR.p.googleapis.com")
131    ///     .build();
132    /// ```
133    pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
134        self.token_uri = Some(token_uri.into());
135        self
136    }
137
138    /// Sets the [scopes] for these credentials.
139    ///
140    /// `scopes` define the *permissions being requested* for this specific access token
141    /// when interacting with a service. For example, `https://www.googleapis.com/auth/devstorage.read_write`.
142    /// IAM permissions, on the other hand, define the *underlying capabilities*
143    /// the user account possesses within a system. For example, `storage.buckets.delete`.
144    /// When a token generated with specific scopes is used, the request must be permitted
145    /// by both the user account's underlying IAM permissions and the scopes requested
146    /// for the token. Therefore, scopes act as an additional restriction on what the token
147    /// can be used for.
148    ///
149    /// # Example
150    /// ```
151    /// # use google_cloud_auth::credentials::user_account::Builder;
152    /// let authorized_user = serde_json::json!({ /* add details here */ });
153    /// let credentials = Builder::new(authorized_user)
154    ///     .with_scopes(["https://www.googleapis.com/auth/pubsub"])
155    ///     .build();
156    /// ```
157    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
158    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
159    where
160        I: IntoIterator<Item = S>,
161        S: Into<String>,
162    {
163        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
164        self
165    }
166
167    /// Sets the [quota project] for these credentials.
168    ///
169    /// In some services, you can use an account in
170    /// one project for authentication and authorization, and charge
171    /// the usage to a different project. This requires that the
172    /// user has `serviceusage.services.use` permissions on the quota project.
173    ///
174    /// Any value set here overrides a `quota_project_id` value from the
175    /// input `authorized_user` JSON.
176    ///
177    /// # Example
178    /// ```
179    /// # use google_cloud_auth::credentials::user_account::Builder;
180    /// let authorized_user = serde_json::json!("{ /* add details here */ }");
181    /// let credentials = Builder::new(authorized_user)
182    ///     .with_quota_project_id("my-project")
183    ///     .build();
184    /// ```
185    ///
186    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
187    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
188        self.quota_project_id = Some(quota_project_id.into());
189        self
190    }
191
192    /// Returns a [Credentials] instance with the configured settings.
193    ///
194    /// # Errors
195    ///
196    /// Returns a [CredentialsError] if the `authorized_user`
197    /// provided to [`Builder::new`] cannot be successfully deserialized into the
198    /// expected format. This typically happens if the JSON value is malformed or
199    /// missing required fields. For more information, on how to generate
200    /// `authorized_user` json, consult the relevant section in the
201    /// [application-default credentials] guide.
202    ///
203    /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
204    pub fn build(self) -> Result<Credentials> {
205        let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
206            .map_err(errors::non_retryable)?;
207        let endpoint = self
208            .token_uri
209            .or(authorized_user.token_uri)
210            .unwrap_or(OAUTH2_ENDPOINT.to_string());
211        let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
212
213        let token_provider = UserTokenProvider {
214            client_id: authorized_user.client_id,
215            client_secret: authorized_user.client_secret,
216            refresh_token: authorized_user.refresh_token,
217            endpoint,
218            scopes: self.scopes.map(|scopes| scopes.join(" ")),
219        };
220        let token_provider = TokenCache::new(token_provider);
221
222        Ok(Credentials {
223            inner: Arc::new(UserCredentials {
224                token_provider,
225                quota_project_id,
226            }),
227        })
228    }
229}
230
/// Token provider that exchanges a user's OAuth 2.0 refresh token for access
/// tokens at `endpoint`.
///
/// `Debug` is implemented by hand (not derived) so that `client_secret` and
/// `refresh_token` are never printed.
#[derive(PartialEq)]
struct UserTokenProvider {
    /// OAuth 2.0 client ID taken from the `authorized_user` JSON.
    client_id: String,
    /// OAuth 2.0 client secret; censored in `Debug` output.
    client_secret: String,
    /// Long-lived refresh token; censored in `Debug` output.
    refresh_token: String,
    /// URL that receives the refresh request.
    endpoint: String,
    /// Space-separated scopes to request, or `None` to send no scopes.
    scopes: Option<String>,
}
239
240impl std::fmt::Debug for UserTokenProvider {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        f.debug_struct("UserCredentials")
243            .field("client_id", &self.client_id)
244            .field("client_secret", &"[censored]")
245            .field("refresh_token", &"[censored]")
246            .field("endpoint", &self.endpoint)
247            .field("scopes", &self.scopes)
248            .finish()
249    }
250}
251
252#[async_trait::async_trait]
253impl TokenProvider for UserTokenProvider {
254    async fn token(&self) -> Result<Token> {
255        let client = Client::new();
256
257        // Make the request
258        let req = Oauth2RefreshRequest {
259            grant_type: RefreshGrantType::RefreshToken,
260            client_id: self.client_id.clone(),
261            client_secret: self.client_secret.clone(),
262            refresh_token: self.refresh_token.clone(),
263            scopes: self.scopes.clone(),
264        };
265        let header = HeaderValue::from_static("application/json");
266        let builder = client
267            .request(Method::POST, self.endpoint.as_str())
268            .header(CONTENT_TYPE, header)
269            .json(&req);
270        let resp = builder.send().await.map_err(errors::retryable)?;
271
272        // Process the response
273        if !resp.status().is_success() {
274            let status = resp.status();
275            let body = resp
276                .text()
277                .await
278                .map_err(|e| CredentialsError::new(is_retryable(status), e))?;
279            return Err(CredentialsError::from_str(
280                is_retryable(status),
281                format!("Failed to fetch token. {body}"),
282            ));
283        }
284        let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
285            let retryable = !e.is_decode();
286            CredentialsError::new(retryable, e)
287        })?;
288        let token = Token {
289            token: response.access_token,
290            token_type: response.token_type,
291            expires_at: response
292                .expires_in
293                .map(|d| std::time::Instant::now() + Duration::from_secs(d)),
294            metadata: None,
295        };
296        Ok(token)
297    }
298}
299
/// Credentials backed by a user account.
///
/// Pairs a token provider with an optional quota project. The provider is
/// generic so that tests can substitute a mock for the real, cached provider.
///
/// See: <https://cloud.google.com/docs/authentication#user-accounts>
#[derive(Debug)]
pub(crate) struct UserCredentials<T>
where
    T: TokenProvider,
{
    /// Source of access tokens.
    token_provider: T,
    /// Quota project passed to `build_bearer_headers` when building headers.
    quota_project_id: Option<String>,
}
311
312#[async_trait::async_trait]
313impl<T> CredentialsTrait for UserCredentials<T>
314where
315    T: TokenProvider,
316{
317    async fn token(&self) -> Result<Token> {
318        self.token_provider.token().await
319    }
320
321    async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>> {
322        let token = self.token().await?;
323        build_bearer_headers(&token, &self.quota_project_id)
324    }
325}
326
/// Deserialized form of the `authorized_user` JSON handed to [`Builder::new`].
///
/// `type`, `client_id`, `client_secret` and `refresh_token` are required;
/// `token_uri` and `quota_project_id` are optional. No `deny_unknown_fields`
/// attribute is set, so additional keys in the input are accepted.
///
/// NOTE(review): `skip_serializing_if` only affects `Serialize`, which this
/// type does not derive, so those attributes currently have no effect.
#[derive(Debug, PartialEq, serde::Deserialize)]
pub(crate) struct AuthorizedUser {
    /// Credential type tag, read from the JSON key `type`.
    #[serde(rename = "type")]
    cred_type: String,
    /// OAuth 2.0 client ID.
    client_id: String,
    /// OAuth 2.0 client secret.
    client_secret: String,
    /// Refresh token exchanged for access tokens.
    refresh_token: String,
    /// Optional token endpoint; the builder falls back to the default endpoint.
    #[serde(skip_serializing_if = "Option::is_none")]
    token_uri: Option<String>,
    /// Optional quota project.
    #[serde(skip_serializing_if = "Option::is_none")]
    quota_project_id: Option<String>,
}
339
/// The `grant_type` value of a refresh request.
///
/// A single-variant enum, so the only value that can be serialized is the
/// string `refresh_token`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
enum RefreshGrantType {
    /// Serialized as `"refresh_token"`.
    #[serde(rename = "refresh_token")]
    RefreshToken,
}
345
/// JSON body sent to the token endpoint to refresh an access token.
///
/// NOTE(review): RFC 6749 Section 6 names the optional parameter `scope`,
/// while this struct serializes it as `scopes` — confirm the endpoint honors
/// this key before relying on scope overrides.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct Oauth2RefreshRequest {
    /// Always `refresh_token` for this request.
    grant_type: RefreshGrantType,
    /// OAuth 2.0 client ID.
    client_id: String,
    /// OAuth 2.0 client secret.
    client_secret: String,
    /// Refresh token being exchanged.
    refresh_token: String,
    /// Space-separated scopes to request. There is no `skip_serializing_if`
    /// here, so `None` is serialized as JSON `null` rather than omitted.
    scopes: Option<String>,
}
354
/// JSON body returned by the token endpoint for a successful refresh.
///
/// Only `access_token`, `token_type` and `expires_in` are used to build the
/// resulting `Token`; `scope` and `refresh_token` are parsed but not read.
/// Optional fields are omitted from the serialized form when `None`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct Oauth2RefreshResponse {
    /// The newly issued access token.
    access_token: String,
    /// Scopes reported by the endpoint, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    scope: Option<String>,
    /// Token lifetime; interpreted as seconds when computing the expiration.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    /// Token type, for example `Bearer`.
    token_type: String,
    /// A replacement refresh token, if the endpoint returned one.
    #[serde(skip_serializing_if = "Option::is_none")]
    refresh_token: Option<String>,
}
366
367#[cfg(test)]
368mod test {
369    use super::*;
370    use crate::credentials::QUOTA_PROJECT_KEY;
371    use crate::credentials::test::HV;
372    use crate::token::test::MockTokenProvider;
373    use axum::extract::Json;
374    use http::StatusCode;
375    use http::header::AUTHORIZATION;
376    use std::error::Error;
377    use std::sync::Mutex;
378    use tokio::task::JoinHandle;
379
    // Return type for tests that propagate errors with `?`; any error type
    // that converts into `Box<dyn Error>` is accepted.
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;
381
382    #[test]
383    fn debug_token_provider() {
384        let expected = UserTokenProvider {
385            client_id: "test-client-id".to_string(),
386            client_secret: "test-client-secret".to_string(),
387            refresh_token: "test-refresh-token".to_string(),
388            endpoint: OAUTH2_ENDPOINT.to_string(),
389            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
390        };
391        let fmt = format!("{expected:?}");
392        assert!(fmt.contains("test-client-id"), "{fmt}");
393        assert!(!fmt.contains("test-client-secret"), "{fmt}");
394        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
395        assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
396        assert!(
397            fmt.contains("https://www.googleapis.com/auth/pubsub"),
398            "{fmt}"
399        );
400    }
401
402    #[test]
403    fn authorized_user_full_from_json_success() {
404        let json = serde_json::json!({
405            "account": "",
406            "client_id": "test-client-id",
407            "client_secret": "test-client-secret",
408            "refresh_token": "test-refresh-token",
409            "type": "authorized_user",
410            "universe_domain": "googleapis.com",
411            "quota_project_id": "test-project",
412            "token_uri" : "test-token-uri",
413        });
414
415        let expected = AuthorizedUser {
416            cred_type: "authorized_user".to_string(),
417            client_id: "test-client-id".to_string(),
418            client_secret: "test-client-secret".to_string(),
419            refresh_token: "test-refresh-token".to_string(),
420            quota_project_id: Some("test-project".to_string()),
421            token_uri: Some("test-token-uri".to_string()),
422        };
423        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
424        assert_eq!(actual, expected);
425    }
426
427    #[test]
428    fn authorized_user_partial_from_json_success() {
429        let json = serde_json::json!({
430            "client_id": "test-client-id",
431            "client_secret": "test-client-secret",
432            "refresh_token": "test-refresh-token",
433            "type": "authorized_user",
434        });
435
436        let expected = AuthorizedUser {
437            cred_type: "authorized_user".to_string(),
438            client_id: "test-client-id".to_string(),
439            client_secret: "test-client-secret".to_string(),
440            refresh_token: "test-refresh-token".to_string(),
441            quota_project_id: None,
442            token_uri: None,
443        };
444        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
445        assert_eq!(actual, expected);
446    }
447
448    #[test]
449    fn authorized_user_from_json_parse_fail() {
450        let json_full = serde_json::json!({
451            "client_id": "test-client-id",
452            "client_secret": "test-client-secret",
453            "refresh_token": "test-refresh-token",
454            "type": "authorized_user",
455            "quota_project_id": "test-project"
456        });
457
458        for required_field in ["client_id", "client_secret", "refresh_token"] {
459            let mut json = json_full.clone();
460            // Remove a required field from the JSON
461            json[required_field].take();
462            serde_json::from_value::<AuthorizedUser>(json)
463                .err()
464                .unwrap();
465        }
466    }
467
468    #[tokio::test]
469    async fn token_success() {
470        let expected = Token {
471            token: "test-token".to_string(),
472            token_type: "Bearer".to_string(),
473            expires_at: None,
474            metadata: None,
475        };
476        let expected_clone = expected.clone();
477
478        let mut mock = MockTokenProvider::new();
479        mock.expect_token()
480            .times(1)
481            .return_once(|| Ok(expected_clone));
482
483        let uc = UserCredentials {
484            token_provider: mock,
485            quota_project_id: None,
486        };
487        let actual = uc.token().await.unwrap();
488        assert_eq!(actual, expected);
489    }
490
491    #[tokio::test]
492    async fn token_failure() {
493        let mut mock = MockTokenProvider::new();
494        mock.expect_token()
495            .times(1)
496            .return_once(|| Err(errors::non_retryable_from_str("fail")));
497
498        let uc = UserCredentials {
499            token_provider: mock,
500            quota_project_id: None,
501        };
502        assert!(uc.token().await.is_err());
503    }
504
505    #[tokio::test]
506    async fn headers_success() {
507        let token = Token {
508            token: "test-token".to_string(),
509            token_type: "Bearer".to_string(),
510            expires_at: None,
511            metadata: None,
512        };
513
514        let mut mock = MockTokenProvider::new();
515        mock.expect_token().times(1).return_once(|| Ok(token));
516
517        let uc = UserCredentials {
518            token_provider: mock,
519            quota_project_id: None,
520        };
521        let headers: Vec<HV> = HV::from(uc.headers().await.unwrap());
522
523        assert_eq!(
524            headers,
525            vec![HV {
526                header: AUTHORIZATION.to_string(),
527                value: "Bearer test-token".to_string(),
528                is_sensitive: true,
529            }]
530        );
531    }
532
533    #[tokio::test]
534    async fn headers_failure() {
535        let mut mock = MockTokenProvider::new();
536        mock.expect_token()
537            .times(1)
538            .return_once(|| Err(errors::non_retryable_from_str("fail")));
539
540        let uc = UserCredentials {
541            token_provider: mock,
542            quota_project_id: None,
543        };
544        assert!(uc.headers().await.is_err());
545    }
546
547    #[tokio::test]
548    async fn headers_with_quota_project_success() {
549        let token = Token {
550            token: "test-token".to_string(),
551            token_type: "Bearer".to_string(),
552            expires_at: None,
553            metadata: None,
554        };
555
556        let mut mock = MockTokenProvider::new();
557        mock.expect_token().times(1).return_once(|| Ok(token));
558
559        let uc = UserCredentials {
560            token_provider: mock,
561            quota_project_id: Some("test-project".to_string()),
562        };
563        let headers: Vec<HV> = HV::from(uc.headers().await.unwrap());
564        assert_eq!(
565            headers,
566            vec![
567                HV {
568                    header: AUTHORIZATION.to_string(),
569                    value: "Bearer test-token".to_string(),
570                    is_sensitive: true,
571                },
572                HV {
573                    header: QUOTA_PROJECT_KEY.to_string(),
574                    value: "test-project".to_string(),
575                    is_sensitive: false,
576                }
577            ]
578        );
579    }
580
581    #[test]
582    fn oauth2_request_serde() {
583        let request = Oauth2RefreshRequest {
584            grant_type: RefreshGrantType::RefreshToken,
585            client_id: "test-client-id".to_string(),
586            client_secret: "test-client-secret".to_string(),
587            refresh_token: "test-refresh-token".to_string(),
588            scopes: Some("scope1 scope2".to_string()),
589        };
590
591        let json = serde_json::to_value(&request).unwrap();
592        let expected = serde_json::json!({
593            "grant_type": "refresh_token",
594            "client_id": "test-client-id",
595            "client_secret": "test-client-secret",
596            "refresh_token": "test-refresh-token",
597            "scopes": "scope1 scope2",
598        });
599        assert_eq!(json, expected);
600        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
601        assert_eq!(request, roundtrip);
602    }
603
604    #[test]
605    fn oauth2_response_serde_full() {
606        let response = Oauth2RefreshResponse {
607            access_token: "test-access-token".to_string(),
608            scope: Some("scope1 scope2".to_string()),
609            expires_in: Some(3600),
610            token_type: "test-token-type".to_string(),
611            refresh_token: Some("test-refresh-token".to_string()),
612        };
613
614        let json = serde_json::to_value(&response).unwrap();
615        let expected = serde_json::json!({
616            "access_token": "test-access-token",
617            "scope": "scope1 scope2",
618            "expires_in": 3600,
619            "token_type": "test-token-type",
620            "refresh_token": "test-refresh-token"
621        });
622        assert_eq!(json, expected);
623        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
624        assert_eq!(response, roundtrip);
625    }
626
627    #[test]
628    fn oauth2_response_serde_partial() {
629        let response = Oauth2RefreshResponse {
630            access_token: "test-access-token".to_string(),
631            scope: None,
632            expires_in: None,
633            token_type: "test-token-type".to_string(),
634            refresh_token: None,
635        };
636
637        let json = serde_json::to_value(&response).unwrap();
638        let expected = serde_json::json!({
639            "access_token": "test-access-token",
640            "token_type": "test-token-type",
641        });
642        assert_eq!(json, expected);
643        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
644        assert_eq!(response, roundtrip);
645    }
646
647    // Starts a server running locally. Returns an (endpoint, handler) pair.
648    async fn start(
649        response_code: StatusCode,
650        response_body: Value,
651        call_count: Arc<Mutex<i32>>,
652    ) -> (String, JoinHandle<()>) {
653        let code = response_code;
654        let body = response_body.clone();
655        let handler = move |req| async move { handle_token_factory(code, body, call_count)(req) };
656        let app = axum::Router::new().route("/token", axum::routing::post(handler));
657        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
658        let addr = listener.local_addr().unwrap();
659        let server = tokio::spawn(async {
660            axum::serve(listener, app).await.unwrap();
661        });
662
663        (
664            format!("http://{}:{}/token", addr.ip(), addr.port()),
665            server,
666        )
667    }
668
669    // Creates a handler that
670    // - verifies fields in an Oauth2RefreshRequest
671    // - returns a pre-canned HTTP response
672    fn handle_token_factory(
673        response_code: StatusCode,
674        response_body: Value,
675        call_count: Arc<std::sync::Mutex<i32>>,
676    ) -> impl Fn(Json<Oauth2RefreshRequest>) -> (StatusCode, String) {
677        move |request: Json<Oauth2RefreshRequest>| -> (StatusCode, String) {
678            let mut count = call_count.lock().unwrap();
679            *count += 1;
680            assert_eq!(request.client_id, "test-client-id");
681            assert_eq!(request.client_secret, "test-client-secret");
682            assert_eq!(request.refresh_token, "test-refresh-token");
683            assert_eq!(request.grant_type, RefreshGrantType::RefreshToken);
684            assert_eq!(
685                request.scopes,
686                response_body["scope"].as_str().map(|s| s.to_owned())
687            );
688
689            (response_code, response_body.to_string())
690        }
691    }
692
693    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
694    async fn token_provider_full() -> TestResult {
695        let response = Oauth2RefreshResponse {
696            access_token: "test-access-token".to_string(),
697            expires_in: Some(3600),
698            refresh_token: Some("test-refresh-token".to_string()),
699            scope: Some("scope1 scope2".to_string()),
700            token_type: "test-token-type".to_string(),
701        };
702        let response_body = serde_json::to_value(&response).unwrap();
703        let (endpoint, _server) =
704            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
705        println!("endpoint = {endpoint}");
706
707        let authorized_user = serde_json::json!({
708            "client_id": "test-client-id",
709            "client_secret": "test-client-secret",
710            "refresh_token": "test-refresh-token",
711            "type": "authorized_user",
712            "token_uri": endpoint,
713        });
714        let cred = Builder::new(authorized_user)
715            .with_scopes(vec!["scope1", "scope2"])
716            .build()?;
717
718        let now = std::time::Instant::now();
719        let token = cred.token().await?;
720        assert_eq!(token.token, "test-access-token");
721        assert_eq!(token.token_type, "test-token-type");
722        assert!(
723            token
724                .expires_at
725                .is_some_and(|d| d >= now + Duration::from_secs(3600)),
726            "now: {:?}, expires_at: {:?}",
727            now,
728            token.expires_at
729        );
730
731        Ok(())
732    }
733
734    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
735    async fn token_provider_full_with_quota_project() -> TestResult {
736        let response = Oauth2RefreshResponse {
737            access_token: "test-access-token".to_string(),
738            expires_in: Some(3600),
739            refresh_token: Some("test-refresh-token".to_string()),
740            scope: None,
741            token_type: "test-token-type".to_string(),
742        };
743        let response_body = serde_json::to_value(&response).unwrap();
744        let (endpoint, _server) =
745            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
746        println!("endpoint = {endpoint}");
747
748        let authorized_user = serde_json::json!({
749            "client_id": "test-client-id",
750            "client_secret": "test-client-secret",
751            "refresh_token": "test-refresh-token",
752            "type": "authorized_user",
753            "token_uri": endpoint,
754        });
755        let cred = Builder::new(authorized_user)
756            .with_quota_project_id("test-project")
757            .build()?;
758
759        let headers: Vec<HV> = HV::from(cred.headers().await.unwrap());
760        assert_eq!(
761            headers,
762            vec![
763                HV {
764                    header: AUTHORIZATION.to_string(),
765                    value: "test-token-type test-access-token".to_string(),
766                    is_sensitive: true,
767                },
768                HV {
769                    header: QUOTA_PROJECT_KEY.to_string(),
770                    value: "test-project".to_string(),
771                    is_sensitive: false,
772                }
773            ]
774        );
775        Ok(())
776    }
777
778    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
779    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
780        let response = Oauth2RefreshResponse {
781            access_token: "test-access-token".to_string(),
782            expires_in: Some(3600),
783            refresh_token: Some("test-refresh-token".to_string()),
784            scope: Some("scope1 scope2".to_string()),
785            token_type: "test-token-type".to_string(),
786        };
787        let response_body = serde_json::to_value(&response).unwrap();
788        let call_count = Arc::new(Mutex::new(0));
789        let (endpoint, _server) = start(StatusCode::OK, response_body, call_count.clone()).await;
790        println!("endpoint = {endpoint}");
791
792        let json = serde_json::json!({
793            "client_id": "test-client-id",
794            "client_secret": "test-client-secret",
795            "refresh_token": "test-refresh-token",
796            "type": "authorized_user",
797            "universe_domain": "googleapis.com",
798            "quota_project_id": "test-project",
799            "token_uri": endpoint,
800        });
801
802        let cred = Builder::new(json)
803            .with_scopes(vec!["scope1", "scope2"])
804            .build()?;
805
806        let token = cred.token().await?;
807        assert_eq!(token.token, "test-access-token");
808
809        let token = cred.token().await?;
810        assert_eq!(token.token, "test-access-token");
811
812        // Test that the inner token provider was called only
813        // once even though token was called twice.
814        assert_eq!(*call_count.lock().unwrap(), 1);
815
816        Ok(())
817    }
818
819    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
820    async fn token_provider_partial() -> TestResult {
821        let response = Oauth2RefreshResponse {
822            access_token: "test-access-token".to_string(),
823            expires_in: None,
824            refresh_token: None,
825            scope: None,
826            token_type: "test-token-type".to_string(),
827        };
828        let response_body = serde_json::to_value(&response).unwrap();
829        let (endpoint, _server) =
830            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
831        println!("endpoint = {endpoint}");
832
833        let authorized_user = serde_json::json!({
834            "client_id": "test-client-id",
835            "client_secret": "test-client-secret",
836            "refresh_token": "test-refresh-token",
837            "type": "authorized_user",
838            "token_uri": endpoint});
839
840        let uc = Builder::new(authorized_user).build()?;
841        let token = uc.token().await?;
842        assert_eq!(token.token, "test-access-token");
843        assert_eq!(token.token_type, "test-token-type");
844        assert_eq!(token.expires_at, None);
845
846        Ok(())
847    }
848
849    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
850    async fn token_provider_with_token_uri() -> TestResult {
851        let response = Oauth2RefreshResponse {
852            access_token: "test-access-token".to_string(),
853            expires_in: None,
854            refresh_token: None,
855            scope: None,
856            token_type: "test-token-type".to_string(),
857        };
858        let response_body = serde_json::to_value(&response).unwrap();
859        let (endpoint, _server) =
860            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
861        println!("endpoint = {endpoint}");
862
863        let authorized_user = serde_json::json!({
864            "client_id": "test-client-id",
865            "client_secret": "test-client-secret",
866            "refresh_token": "test-refresh-token",
867            "type": "authorized_user",
868            "token_uri": "test-endpoint"});
869
870        let uc = Builder::new(authorized_user)
871            .with_token_uri(endpoint)
872            .build()?;
873        let token = uc.token().await?;
874        assert_eq!(token.token, "test-access-token");
875        assert_eq!(token.token_type, "test-token-type");
876        assert_eq!(token.expires_at, None);
877
878        Ok(())
879    }
880
881    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
882    async fn token_provider_with_scopes() -> TestResult {
883        let response = Oauth2RefreshResponse {
884            access_token: "test-access-token".to_string(),
885            expires_in: None,
886            refresh_token: None,
887            scope: Some("scope1 scope2".to_string()),
888            token_type: "test-token-type".to_string(),
889        };
890        let response_body = serde_json::to_value(&response).unwrap();
891        let (endpoint, _server) =
892            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
893        println!("endpoint = {endpoint}");
894
895        let authorized_user = serde_json::json!({
896            "client_id": "test-client-id",
897            "client_secret": "test-client-secret",
898            "refresh_token": "test-refresh-token",
899            "type": "authorized_user",
900            "token_uri": "test-endpoint"});
901
902        let uc = Builder::new(authorized_user)
903            .with_token_uri(endpoint)
904            .with_scopes(vec!["scope1", "scope2"])
905            .build()?;
906        let token = uc.token().await?;
907        assert_eq!(token.token, "test-access-token");
908        assert_eq!(token.token_type, "test-token-type");
909        assert_eq!(token.expires_at, None);
910
911        Ok(())
912    }
913
914    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
915    async fn token_provider_retryable_error() -> TestResult {
916        let (endpoint, _server) = start(
917            StatusCode::SERVICE_UNAVAILABLE,
918            serde_json::to_value("try again".to_string())?,
919            Arc::new(Mutex::new(0)),
920        )
921        .await;
922        println!("endpoint = {endpoint}");
923
924        let authorized_user = serde_json::json!({
925            "client_id": "test-client-id",
926            "client_secret": "test-client-secret",
927            "refresh_token": "test-refresh-token",
928            "type": "authorized_user",
929            "token_uri": endpoint});
930
931        let uc = Builder::new(authorized_user).build()?;
932        let e = uc.token().await.err().unwrap();
933        assert!(e.is_retryable(), "{e}");
934        assert!(e.source().unwrap().to_string().contains("try again"), "{e}");
935
936        Ok(())
937    }
938
939    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
940    async fn token_provider_nonretryable_error() -> TestResult {
941        let (endpoint, _server) = start(
942            StatusCode::UNAUTHORIZED,
943            serde_json::to_value("epic fail".to_string())?,
944            Arc::new(Mutex::new(0)),
945        )
946        .await;
947        println!("endpoint = {endpoint}");
948
949        let authorized_user = serde_json::json!({
950            "client_id": "test-client-id",
951            "client_secret": "test-client-secret",
952            "refresh_token": "test-refresh-token",
953            "type": "authorized_user",
954            "token_uri": endpoint});
955
956        let uc = Builder::new(authorized_user).build()?;
957        let e = uc.token().await.err().unwrap();
958        assert!(!e.is_retryable(), "{e}");
959        assert!(e.source().unwrap().to_string().contains("epic fail"), "{e}");
960
961        Ok(())
962    }
963
964    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
965    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
966        let (endpoint, _server) = start(
967            StatusCode::OK,
968            serde_json::to_value("bad json".to_string())?,
969            Arc::new(Mutex::new(0)),
970        )
971        .await;
972        println!("endpoint = {endpoint}");
973
974        let authorized_user = serde_json::json!({
975            "client_id": "test-client-id",
976            "client_secret": "test-client-secret",
977            "refresh_token": "test-refresh-token",
978            "type": "authorized_user",
979            "token_uri": endpoint});
980
981        let uc = Builder::new(authorized_user).build()?;
982        let e = uc.token().await.err().unwrap();
983        assert!(!e.is_retryable(), "{e}");
984
985        Ok(())
986    }
987
988    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
989    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
990        let authorized_user = serde_json::json!({
991        "client_secret": "test-client-secret",
992        "refresh_token": "test-refresh-token",
993        "type": "authorized_user",
994        });
995
996        let e = Builder::new(authorized_user).build().unwrap_err();
997        assert!(!e.is_retryable(), "{e}");
998
999        Ok(())
1000    }
1001}