google_cloud_auth/credentials/
user_account.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [User Account] Credentials type.
16//!
17//! User accounts represent a developer, administrator, or any other person who
18//! interacts with Google APIs and services. User accounts are managed as
19//! [Google Accounts], either via [Google Workspace] or [Cloud Identity].
20//!
21//! This module provides [Credentials] derived from user account
22//! information, specifically utilizing an OAuth 2.0 refresh token.
23//!
24//! This module is designed for refresh tokens obtained via the standard
25//! [Authorization Code grant]. Acquiring the initial refresh token (e.g., through
26//! user consent) is outside the scope of this library.
27//! See [RFC 6749 Section 4.1] for flow details.
28//!
29//! The Google Cloud client libraries for Rust will typically find and use these
30//! credentials automatically if a credentials file exists in the
31//! standard ADC search paths. This file is often created by running:
32//! `gcloud auth application-default login`. You might instantiate these credentials
33//! directly using the [`Builder`] if you need to:
34//! * Load credentials from a non-standard location or source.
35//! * Override the OAuth 2.0 **scopes** being requested for the access token.
36//! * Override the **quota project ID** for billing and quota management.
37//! * Override the **token URI** used to fetch access tokens.
38//!
39//! Example usage:
40//!
41//! ```
42//! # use google_cloud_auth::credentials::user_account::Builder;
43//! # use google_cloud_auth::credentials::Credentials;
44//! # use http::Extensions;
45//! # tokio_test::block_on(async {
46//! let authorized_user = serde_json::json!({
47//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com", // Replace with your actual Client ID
48//!     "client_secret": "YOUR_CLIENT_SECRET", // Replace with your actual Client Secret - LOAD SECURELY!
49//!     "refresh_token": "YOUR_REFRESH_TOKEN", // Replace with the user's refresh token - LOAD SECURELY!
50//!     "type": "authorized_user",
51//!     // "quota_project_id": "your-billing-project-id", // Optional: Set if needed
52//!     // "token_uri" : "test-token-uri", // Optional: Set if needed
53//! });
54//! let credentials: Credentials = Builder::new(authorized_user).build()?;
55//! let headers = credentials.headers(Extensions::new()).await?;
56//! println!("Headers: {headers:?}");
57//! # Ok::<(), anyhow::Error>(())
58//! # });
59//! ```
60//!
61//! [Authorization Code grant]: https://tools.ietf.org/html/rfc6749#section-1.3.1
62//! [Cloud Identity]: https://cloud.google.com/identity
63//! [Google Accounts]: https://myaccount.google.com/
64//! [Google Workspace]: https://workspace.google.com/
65//! [RFC 6749 Section 4.1]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.1
66//! [User Account]: https://cloud.google.com/docs/authentication#user-accounts
67//! [Workforce Identity Federation]: https://cloud.google.com/iam/docs/workforce-identity-federation
68
69use crate::build_errors::Error as BuilderError;
70use crate::credentials::dynamic::CredentialsProvider;
71use crate::credentials::{CacheableResource, Credentials};
72use crate::errors::{self, CredentialsError};
73use crate::headers_util::build_cacheable_headers;
74use crate::token::{CachedTokenProvider, Token, TokenProvider};
75use crate::token_cache::TokenCache;
76use crate::{BuildResult, Result};
77use http::header::CONTENT_TYPE;
78use http::{Extensions, HeaderMap, HeaderValue};
79use reqwest::{Client, Method};
80use serde_json::Value;
81use std::sync::Arc;
82use tokio::time::{Duration, Instant};
83
/// Default OAuth 2.0 token endpoint.
///
/// Used by [`Builder::build`] when neither the builder nor the
/// `authorized_user` JSON supplies a `token_uri`.
const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
85
/// A builder for constructing a `user_account` [Credentials] instance.
///
/// The builder holds the raw `authorized_user` JSON plus optional overrides;
/// nothing is parsed or validated until [`Builder::build`] is called.
///
/// # Example
/// ```
/// # use google_cloud_auth::credentials::user_account::Builder;
/// # tokio_test::block_on(async {
/// let authorized_user = serde_json::json!({ /* add details here */ });
/// let credentials = Builder::new(authorized_user).build();
/// # });
/// ```
pub struct Builder {
    /// Raw `authorized_user` JSON; deserialized in `build()`.
    authorized_user: Value,
    /// Scopes to request, if set via `with_scopes()`.
    scopes: Option<Vec<String>>,
    /// Overrides any `quota_project_id` found in the JSON.
    quota_project_id: Option<String>,
    /// Overrides any `token_uri` found in the JSON.
    token_uri: Option<String>,
}
102
103impl Builder {
104    /// Creates a new builder using `authorized_user` JSON value.
105    ///
106    /// The `authorized_user` JSON is typically generated when a user
107    /// authenticates using the [application-default login] process.
108    ///
109    /// [application-default login]: https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login
110    pub fn new(authorized_user: Value) -> Self {
111        Self {
112            authorized_user,
113            scopes: None,
114            quota_project_id: None,
115            token_uri: None,
116        }
117    }
118
119    /// Sets the URI for the token endpoint used to fetch access tokens.
120    ///
121    /// Any value provided here overrides a `token_uri` value from the input `authorized_user` JSON.
122    /// Defaults to `https://oauth2.googleapis.com/token` if not specified here or in the `authorized_user` JSON.
123    ///
124    /// # Example
125    /// ```
126    /// # use google_cloud_auth::credentials::user_account::Builder;
127    /// let authorized_user = serde_json::json!({ /* add details here */ });
128    /// let credentials = Builder::new(authorized_user)
129    ///     .with_token_uri("https://oauth2-FOOBAR.p.googleapis.com")
130    ///     .build();
131    /// ```
132    pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
133        self.token_uri = Some(token_uri.into());
134        self
135    }
136
137    /// Sets the [scopes] for these credentials.
138    ///
139    /// `scopes` define the *permissions being requested* for this specific access token
140    /// when interacting with a service. For example, `https://www.googleapis.com/auth/devstorage.read_write`.
141    /// IAM permissions, on the other hand, define the *underlying capabilities*
142    /// the user account possesses within a system. For example, `storage.buckets.delete`.
143    /// When a token generated with specific scopes is used, the request must be permitted
144    /// by both the user account's underlying IAM permissions and the scopes requested
145    /// for the token. Therefore, scopes act as an additional restriction on what the token
146    /// can be used for.
147    ///
148    /// # Example
149    /// ```
150    /// # use google_cloud_auth::credentials::user_account::Builder;
151    /// let authorized_user = serde_json::json!({ /* add details here */ });
152    /// let credentials = Builder::new(authorized_user)
153    ///     .with_scopes(["https://www.googleapis.com/auth/pubsub"])
154    ///     .build();
155    /// ```
156    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
157    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
158    where
159        I: IntoIterator<Item = S>,
160        S: Into<String>,
161    {
162        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
163        self
164    }
165
166    /// Sets the [quota project] for these credentials.
167    ///
168    /// In some services, you can use an account in
169    /// one project for authentication and authorization, and charge
170    /// the usage to a different project. This requires that the
171    /// user has `serviceusage.services.use` permissions on the quota project.
172    ///
173    /// Any value set here overrides a `quota_project_id` value from the
174    /// input `authorized_user` JSON.
175    ///
176    /// # Example
177    /// ```
178    /// # use google_cloud_auth::credentials::user_account::Builder;
179    /// let authorized_user = serde_json::json!("{ /* add details here */ }");
180    /// let credentials = Builder::new(authorized_user)
181    ///     .with_quota_project_id("my-project")
182    ///     .build();
183    /// ```
184    ///
185    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
186    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
187        self.quota_project_id = Some(quota_project_id.into());
188        self
189    }
190
191    /// Returns a [Credentials] instance with the configured settings.
192    ///
193    /// # Errors
194    ///
195    /// Returns a [CredentialsError] if the `authorized_user`
196    /// provided to [`Builder::new`] cannot be successfully deserialized into the
197    /// expected format. This typically happens if the JSON value is malformed or
198    /// missing required fields. For more information, on how to generate
199    /// `authorized_user` json, consult the relevant section in the
200    /// [application-default credentials] guide.
201    ///
202    /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
203    pub fn build(self) -> BuildResult<Credentials> {
204        let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
205            .map_err(BuilderError::parsing)?;
206        let endpoint = self
207            .token_uri
208            .or(authorized_user.token_uri)
209            .unwrap_or(OAUTH2_ENDPOINT.to_string());
210        let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
211
212        let token_provider = UserTokenProvider {
213            client_id: authorized_user.client_id,
214            client_secret: authorized_user.client_secret,
215            refresh_token: authorized_user.refresh_token,
216            endpoint,
217            scopes: self.scopes.map(|scopes| scopes.join(" ")),
218        };
219        let token_provider = TokenCache::new(token_provider);
220
221        Ok(Credentials {
222            inner: Arc::new(UserCredentials {
223                token_provider,
224                quota_project_id,
225            }),
226        })
227    }
228}
229
/// Holds everything needed to exchange a user's refresh token for an access
/// token at `endpoint`.
#[derive(PartialEq)]
struct UserTokenProvider {
    client_id: String,
    client_secret: String,
    refresh_token: String,
    /// Token endpoint the refresh request is sent to.
    endpoint: String,
    /// Space-delimited scopes, if any were configured.
    scopes: Option<String>,
}

// `Debug` is written by hand (not derived) so that `client_secret` and
// `refresh_token` are never printed.
impl std::fmt::Debug for UserTokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report the real type name; it previously said "UserCredentials".
        f.debug_struct("UserTokenProvider")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[censored]")
            .field("refresh_token", &"[censored]")
            .field("endpoint", &self.endpoint)
            .field("scopes", &self.scopes)
            .finish()
    }
}
250
251#[async_trait::async_trait]
252impl TokenProvider for UserTokenProvider {
253    async fn token(&self) -> Result<Token> {
254        let client = Client::new();
255
256        // Make the request
257        let req = Oauth2RefreshRequest {
258            grant_type: RefreshGrantType::RefreshToken,
259            client_id: self.client_id.clone(),
260            client_secret: self.client_secret.clone(),
261            refresh_token: self.refresh_token.clone(),
262            scopes: self.scopes.clone(),
263        };
264        let header = HeaderValue::from_static("application/json");
265        let builder = client
266            .request(Method::POST, self.endpoint.as_str())
267            .header(CONTENT_TYPE, header)
268            .json(&req);
269        let resp = builder
270            .send()
271            .await
272            .map_err(|e| errors::from_http_error(e, MSG))?;
273
274        // Process the response
275        if !resp.status().is_success() {
276            let err = errors::from_http_response(resp, MSG).await;
277            return Err(err);
278        }
279        let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
280            let retryable = !e.is_decode();
281            CredentialsError::from_source(retryable, e)
282        })?;
283        let token = Token {
284            token: response.access_token,
285            token_type: response.token_type,
286            expires_at: response
287                .expires_in
288                .map(|d| Instant::now() + Duration::from_secs(d)),
289            metadata: None,
290        };
291        Ok(token)
292    }
293}
294
/// Message passed to the `errors` helpers when the token refresh request
/// fails to send or returns a non-success status.
const MSG: &str = "failed to refresh user access token";
296
/// Data model for user account credentials.
///
/// Pairs a token provider with an optional quota project; both are used to
/// build the request headers.
///
/// See: <https://cloud.google.com/docs/authentication#user-accounts>
#[derive(Debug)]
pub(crate) struct UserCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Source of access tokens.
    token_provider: T,
    /// Quota project passed to the header builder, if any.
    quota_project_id: Option<String>,
}
308
309#[async_trait::async_trait]
310impl<T> CredentialsProvider for UserCredentials<T>
311where
312    T: CachedTokenProvider,
313{
314    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
315        let token = self.token_provider.token(extensions).await?;
316        build_cacheable_headers(&token, &self.quota_project_id)
317    }
318}
319
320#[derive(Debug, PartialEq, serde::Deserialize)]
321pub(crate) struct AuthorizedUser {
322    #[serde(rename = "type")]
323    cred_type: String,
324    client_id: String,
325    client_secret: String,
326    refresh_token: String,
327    #[serde(skip_serializing_if = "Option::is_none")]
328    token_uri: Option<String>,
329    #[serde(skip_serializing_if = "Option::is_none")]
330    quota_project_id: Option<String>,
331}
332
333#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
334enum RefreshGrantType {
335    #[serde(rename = "refresh_token")]
336    RefreshToken,
337}
338
/// JSON body sent to the token endpoint to refresh an access token.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct Oauth2RefreshRequest {
    grant_type: RefreshGrantType,
    client_id: String,
    client_secret: String,
    refresh_token: String,
    // Space-delimited scopes; serialized under the key `scopes`.
    // NOTE(review): RFC 6749 section 6 names this parameter `scope` —
    // confirm the endpoint accepts `scopes`.
    scopes: Option<String>,
}
347
/// JSON body returned by the token endpoint for a refresh request.
///
/// Optional fields are omitted (not written as `null`) when serialized.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct Oauth2RefreshResponse {
    // The access token; becomes `Token::token`.
    access_token: String,
    // Scopes reported by the server; not read by the token provider here.
    #[serde(skip_serializing_if = "Option::is_none")]
    scope: Option<String>,
    // Token lifetime in seconds (it is fed to `Duration::from_secs`).
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    // Becomes `Token::token_type`.
    token_type: String,
    // A refresh token, if the server returns one; not read by the token provider here.
    #[serde(skip_serializing_if = "Option::is_none")]
    refresh_token: Option<String>,
}
359
360#[cfg(test)]
361mod test {
362    use super::*;
363    use crate::credentials::test::{
364        get_headers_from_cache, get_token_from_headers, get_token_type_from_headers,
365    };
366    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
367    use crate::token::test::MockTokenProvider;
368    use axum::extract::Json;
369    use http::StatusCode;
370    use http::header::AUTHORIZATION;
371    use std::error::Error;
372    use std::sync::Mutex;
373    use tokio::task::JoinHandle;
374
    // Return type for tests that use `?`; any error type can be boxed into it.
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;
376
377    #[test]
378    fn debug_token_provider() {
379        let expected = UserTokenProvider {
380            client_id: "test-client-id".to_string(),
381            client_secret: "test-client-secret".to_string(),
382            refresh_token: "test-refresh-token".to_string(),
383            endpoint: OAUTH2_ENDPOINT.to_string(),
384            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
385        };
386        let fmt = format!("{expected:?}");
387        assert!(fmt.contains("test-client-id"), "{fmt}");
388        assert!(!fmt.contains("test-client-secret"), "{fmt}");
389        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
390        assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
391        assert!(
392            fmt.contains("https://www.googleapis.com/auth/pubsub"),
393            "{fmt}"
394        );
395    }
396
397    #[test]
398    fn authorized_user_full_from_json_success() {
399        let json = serde_json::json!({
400            "account": "",
401            "client_id": "test-client-id",
402            "client_secret": "test-client-secret",
403            "refresh_token": "test-refresh-token",
404            "type": "authorized_user",
405            "universe_domain": "googleapis.com",
406            "quota_project_id": "test-project",
407            "token_uri" : "test-token-uri",
408        });
409
410        let expected = AuthorizedUser {
411            cred_type: "authorized_user".to_string(),
412            client_id: "test-client-id".to_string(),
413            client_secret: "test-client-secret".to_string(),
414            refresh_token: "test-refresh-token".to_string(),
415            quota_project_id: Some("test-project".to_string()),
416            token_uri: Some("test-token-uri".to_string()),
417        };
418        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
419        assert_eq!(actual, expected);
420    }
421
422    #[test]
423    fn authorized_user_partial_from_json_success() {
424        let json = serde_json::json!({
425            "client_id": "test-client-id",
426            "client_secret": "test-client-secret",
427            "refresh_token": "test-refresh-token",
428            "type": "authorized_user",
429        });
430
431        let expected = AuthorizedUser {
432            cred_type: "authorized_user".to_string(),
433            client_id: "test-client-id".to_string(),
434            client_secret: "test-client-secret".to_string(),
435            refresh_token: "test-refresh-token".to_string(),
436            quota_project_id: None,
437            token_uri: None,
438        };
439        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
440        assert_eq!(actual, expected);
441    }
442
443    #[test]
444    fn authorized_user_from_json_parse_fail() {
445        let json_full = serde_json::json!({
446            "client_id": "test-client-id",
447            "client_secret": "test-client-secret",
448            "refresh_token": "test-refresh-token",
449            "type": "authorized_user",
450            "quota_project_id": "test-project"
451        });
452
453        for required_field in ["client_id", "client_secret", "refresh_token"] {
454            let mut json = json_full.clone();
455            // Remove a required field from the JSON
456            json[required_field].take();
457            serde_json::from_value::<AuthorizedUser>(json)
458                .err()
459                .unwrap();
460        }
461    }
462
463    #[tokio::test]
464    async fn default_universe_domain_success() {
465        let mock = TokenCache::new(MockTokenProvider::new());
466
467        let uc = UserCredentials {
468            token_provider: mock,
469            quota_project_id: None,
470        };
471        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
472    }
473
474    #[tokio::test]
475    async fn headers_success() -> TestResult {
476        let token = Token {
477            token: "test-token".to_string(),
478            token_type: "Bearer".to_string(),
479            expires_at: None,
480            metadata: None,
481        };
482
483        let mut mock = MockTokenProvider::new();
484        mock.expect_token().times(1).return_once(|| Ok(token));
485
486        let uc = UserCredentials {
487            token_provider: TokenCache::new(mock),
488            quota_project_id: None,
489        };
490
491        let mut extensions = Extensions::new();
492        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
493        let (headers, entity_tag) = match cached_headers {
494            CacheableResource::New { entity_tag, data } => (data, entity_tag),
495            CacheableResource::NotModified => unreachable!("expecting new headers"),
496        };
497        let token = headers.get(AUTHORIZATION).unwrap();
498
499        assert_eq!(headers.len(), 1, "{headers:?}");
500        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
501        assert!(token.is_sensitive());
502
503        extensions.insert(entity_tag);
504
505        let cached_headers = uc.headers(extensions).await?;
506
507        match cached_headers {
508            CacheableResource::New { .. } => unreachable!("expecting new headers"),
509            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
510        };
511        Ok(())
512    }
513
514    #[tokio::test]
515    async fn headers_failure() {
516        let mut mock = MockTokenProvider::new();
517        mock.expect_token()
518            .times(1)
519            .return_once(|| Err(errors::non_retryable_from_str("fail")));
520
521        let uc = UserCredentials {
522            token_provider: TokenCache::new(mock),
523            quota_project_id: None,
524        };
525        assert!(uc.headers(Extensions::new()).await.is_err());
526    }
527
528    #[tokio::test]
529    async fn headers_with_quota_project_success() -> TestResult {
530        let token = Token {
531            token: "test-token".to_string(),
532            token_type: "Bearer".to_string(),
533            expires_at: None,
534            metadata: None,
535        };
536
537        let mut mock = MockTokenProvider::new();
538        mock.expect_token().times(1).return_once(|| Ok(token));
539
540        let uc = UserCredentials {
541            token_provider: TokenCache::new(mock),
542            quota_project_id: Some("test-project".to_string()),
543        };
544
545        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
546        let token = headers.get(AUTHORIZATION).unwrap();
547        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
548
549        assert_eq!(headers.len(), 2, "{headers:?}");
550        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
551        assert!(token.is_sensitive());
552        assert_eq!(
553            quota_project_header,
554            HeaderValue::from_static("test-project")
555        );
556        assert!(!quota_project_header.is_sensitive());
557        Ok(())
558    }
559
560    #[test]
561    fn oauth2_request_serde() {
562        let request = Oauth2RefreshRequest {
563            grant_type: RefreshGrantType::RefreshToken,
564            client_id: "test-client-id".to_string(),
565            client_secret: "test-client-secret".to_string(),
566            refresh_token: "test-refresh-token".to_string(),
567            scopes: Some("scope1 scope2".to_string()),
568        };
569
570        let json = serde_json::to_value(&request).unwrap();
571        let expected = serde_json::json!({
572            "grant_type": "refresh_token",
573            "client_id": "test-client-id",
574            "client_secret": "test-client-secret",
575            "refresh_token": "test-refresh-token",
576            "scopes": "scope1 scope2",
577        });
578        assert_eq!(json, expected);
579        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
580        assert_eq!(request, roundtrip);
581    }
582
583    #[test]
584    fn oauth2_response_serde_full() {
585        let response = Oauth2RefreshResponse {
586            access_token: "test-access-token".to_string(),
587            scope: Some("scope1 scope2".to_string()),
588            expires_in: Some(3600),
589            token_type: "test-token-type".to_string(),
590            refresh_token: Some("test-refresh-token".to_string()),
591        };
592
593        let json = serde_json::to_value(&response).unwrap();
594        let expected = serde_json::json!({
595            "access_token": "test-access-token",
596            "scope": "scope1 scope2",
597            "expires_in": 3600,
598            "token_type": "test-token-type",
599            "refresh_token": "test-refresh-token"
600        });
601        assert_eq!(json, expected);
602        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
603        assert_eq!(response, roundtrip);
604    }
605
606    #[test]
607    fn oauth2_response_serde_partial() {
608        let response = Oauth2RefreshResponse {
609            access_token: "test-access-token".to_string(),
610            scope: None,
611            expires_in: None,
612            token_type: "test-token-type".to_string(),
613            refresh_token: None,
614        };
615
616        let json = serde_json::to_value(&response).unwrap();
617        let expected = serde_json::json!({
618            "access_token": "test-access-token",
619            "token_type": "test-token-type",
620        });
621        assert_eq!(json, expected);
622        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
623        assert_eq!(response, roundtrip);
624    }
625
626    // Starts a server running locally. Returns an (endpoint, handler) pair.
627    async fn start(
628        response_code: StatusCode,
629        response_body: Value,
630        call_count: Arc<Mutex<i32>>,
631    ) -> (String, JoinHandle<()>) {
632        let code = response_code;
633        let body = response_body.clone();
634        let handler = move |req| async move { handle_token_factory(code, body, call_count)(req) };
635        let app = axum::Router::new().route("/token", axum::routing::post(handler));
636        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
637        let addr = listener.local_addr().unwrap();
638        let server = tokio::spawn(async {
639            axum::serve(listener, app).await.unwrap();
640        });
641
642        (
643            format!("http://{}:{}/token", addr.ip(), addr.port()),
644            server,
645        )
646    }
647
648    // Creates a handler that
649    // - verifies fields in an Oauth2RefreshRequest
650    // - returns a pre-canned HTTP response
651    fn handle_token_factory(
652        response_code: StatusCode,
653        response_body: Value,
654        call_count: Arc<std::sync::Mutex<i32>>,
655    ) -> impl Fn(Json<Oauth2RefreshRequest>) -> (StatusCode, String) {
656        move |request: Json<Oauth2RefreshRequest>| -> (StatusCode, String) {
657            let mut count = call_count.lock().unwrap();
658            *count += 1;
659            assert_eq!(request.client_id, "test-client-id");
660            assert_eq!(request.client_secret, "test-client-secret");
661            assert_eq!(request.refresh_token, "test-refresh-token");
662            assert_eq!(request.grant_type, RefreshGrantType::RefreshToken);
663            assert_eq!(
664                request.scopes,
665                response_body["scope"].as_str().map(|s| s.to_owned())
666            );
667
668            (response_code, response_body.to_string())
669        }
670    }
671
672    #[tokio::test(start_paused = true)]
673    async fn token_provider_full() -> TestResult {
674        let response = Oauth2RefreshResponse {
675            access_token: "test-access-token".to_string(),
676            expires_in: Some(3600),
677            refresh_token: Some("test-refresh-token".to_string()),
678            scope: Some("scope1 scope2".to_string()),
679            token_type: "test-token-type".to_string(),
680        };
681        let response_body = serde_json::to_value(&response).unwrap();
682        let (endpoint, _server) =
683            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
684        println!("endpoint = {endpoint}");
685
686        let tp = UserTokenProvider {
687            client_id: "test-client-id".to_string(),
688            client_secret: "test-client-secret".to_string(),
689            refresh_token: "test-refresh-token".to_string(),
690            endpoint,
691            scopes: Some("scope1 scope2".to_string()),
692        };
693        let now = Instant::now();
694        let token = tp.token().await?;
695        assert_eq!(token.token, "test-access-token");
696        assert_eq!(token.token_type, "test-token-type");
697        assert!(
698            token
699                .expires_at
700                .is_some_and(|d| d == now + Duration::from_secs(3600)),
701            "now: {:?}, expires_at: {:?}",
702            now,
703            token.expires_at
704        );
705
706        Ok(())
707    }
708
709    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
710    async fn credential_full_with_quota_project() -> TestResult {
711        let response = Oauth2RefreshResponse {
712            access_token: "test-access-token".to_string(),
713            expires_in: Some(3600),
714            refresh_token: Some("test-refresh-token".to_string()),
715            scope: None,
716            token_type: "test-token-type".to_string(),
717        };
718        let response_body = serde_json::to_value(&response).unwrap();
719        let (endpoint, _server) =
720            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
721        println!("endpoint = {endpoint}");
722
723        let authorized_user = serde_json::json!({
724            "client_id": "test-client-id",
725            "client_secret": "test-client-secret",
726            "refresh_token": "test-refresh-token",
727            "type": "authorized_user",
728            "token_uri": endpoint,
729        });
730        let cred = Builder::new(authorized_user)
731            .with_quota_project_id("test-project")
732            .build()?;
733
734        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
735        let token = headers.get(AUTHORIZATION).unwrap();
736        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
737
738        assert_eq!(headers.len(), 2, "{headers:?}");
739        assert_eq!(
740            token,
741            HeaderValue::from_static("test-token-type test-access-token")
742        );
743        assert!(token.is_sensitive());
744        assert_eq!(
745            quota_project_header,
746            HeaderValue::from_static("test-project")
747        );
748        assert!(!quota_project_header.is_sensitive());
749
750        Ok(())
751    }
752
753    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
754    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
755        let response = Oauth2RefreshResponse {
756            access_token: "test-access-token".to_string(),
757            expires_in: Some(3600),
758            refresh_token: Some("test-refresh-token".to_string()),
759            scope: Some("scope1 scope2".to_string()),
760            token_type: "test-token-type".to_string(),
761        };
762        let response_body = serde_json::to_value(&response).unwrap();
763        let call_count = Arc::new(Mutex::new(0));
764        let (endpoint, _server) = start(StatusCode::OK, response_body, call_count.clone()).await;
765        println!("endpoint = {endpoint}");
766
767        let json = serde_json::json!({
768            "client_id": "test-client-id",
769            "client_secret": "test-client-secret",
770            "refresh_token": "test-refresh-token",
771            "type": "authorized_user",
772            "universe_domain": "googleapis.com",
773            "quota_project_id": "test-project",
774            "token_uri": endpoint,
775        });
776
777        let cred = Builder::new(json)
778            .with_scopes(vec!["scope1", "scope2"])
779            .build()?;
780
781        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
782        assert_eq!(token.unwrap(), "test-access-token");
783
784        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
785        assert_eq!(token.unwrap(), "test-access-token");
786
787        // Test that the inner token provider was called only
788        // once even though token was called twice.
789        assert_eq!(*call_count.lock().unwrap(), 1);
790
791        Ok(())
792    }
793
794    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
795    async fn credential_provider_partial() -> TestResult {
796        let response = Oauth2RefreshResponse {
797            access_token: "test-access-token".to_string(),
798            expires_in: None,
799            refresh_token: None,
800            scope: None,
801            token_type: "test-token-type".to_string(),
802        };
803        let response_body = serde_json::to_value(&response).unwrap();
804        let (endpoint, _server) =
805            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
806        println!("endpoint = {endpoint}");
807
808        let authorized_user = serde_json::json!({
809            "client_id": "test-client-id",
810            "client_secret": "test-client-secret",
811            "refresh_token": "test-refresh-token",
812            "type": "authorized_user",
813            "token_uri": endpoint});
814
815        let uc = Builder::new(authorized_user).build()?;
816        let headers = uc.headers(Extensions::new()).await?;
817        assert_eq!(
818            get_token_from_headers(headers.clone()).unwrap(),
819            "test-access-token"
820        );
821        assert_eq!(
822            get_token_type_from_headers(headers).unwrap(),
823            "test-token-type"
824        );
825
826        Ok(())
827    }
828
829    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
830    async fn credential_provider_with_token_uri() -> TestResult {
831        let response = Oauth2RefreshResponse {
832            access_token: "test-access-token".to_string(),
833            expires_in: None,
834            refresh_token: None,
835            scope: None,
836            token_type: "test-token-type".to_string(),
837        };
838        let response_body = serde_json::to_value(&response).unwrap();
839        let (endpoint, _server) =
840            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
841        println!("endpoint = {endpoint}");
842
843        let authorized_user = serde_json::json!({
844            "client_id": "test-client-id",
845            "client_secret": "test-client-secret",
846            "refresh_token": "test-refresh-token",
847            "type": "authorized_user",
848            "token_uri": "test-endpoint"});
849
850        let uc = Builder::new(authorized_user)
851            .with_token_uri(endpoint)
852            .build()?;
853        let headers = uc.headers(Extensions::new()).await?;
854        assert_eq!(
855            get_token_from_headers(headers.clone()).unwrap(),
856            "test-access-token"
857        );
858        assert_eq!(
859            get_token_type_from_headers(headers).unwrap(),
860            "test-token-type"
861        );
862
863        Ok(())
864    }
865
866    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
867    async fn credential_provider_with_scopes() -> TestResult {
868        let response = Oauth2RefreshResponse {
869            access_token: "test-access-token".to_string(),
870            expires_in: None,
871            refresh_token: None,
872            scope: Some("scope1 scope2".to_string()),
873            token_type: "test-token-type".to_string(),
874        };
875        let response_body = serde_json::to_value(&response).unwrap();
876        let (endpoint, _server) =
877            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
878        println!("endpoint = {endpoint}");
879
880        let authorized_user = serde_json::json!({
881            "client_id": "test-client-id",
882            "client_secret": "test-client-secret",
883            "refresh_token": "test-refresh-token",
884            "type": "authorized_user",
885            "token_uri": "test-endpoint"});
886
887        let uc = Builder::new(authorized_user)
888            .with_token_uri(endpoint)
889            .with_scopes(vec!["scope1", "scope2"])
890            .build()?;
891        let headers = uc.headers(Extensions::new()).await?;
892        assert_eq!(
893            get_token_from_headers(headers.clone()).unwrap(),
894            "test-access-token"
895        );
896        assert_eq!(
897            get_token_type_from_headers(headers).unwrap(),
898            "test-token-type"
899        );
900
901        Ok(())
902    }
903
904    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
905    async fn credential_provider_retryable_error() -> TestResult {
906        let (endpoint, _server) = start(
907            StatusCode::SERVICE_UNAVAILABLE,
908            serde_json::to_value("try again".to_string())?,
909            Arc::new(Mutex::new(0)),
910        )
911        .await;
912        println!("endpoint = {endpoint}");
913
914        let authorized_user = serde_json::json!({
915            "client_id": "test-client-id",
916            "client_secret": "test-client-secret",
917            "refresh_token": "test-refresh-token",
918            "type": "authorized_user",
919            "token_uri": endpoint});
920
921        let uc = Builder::new(authorized_user).build()?;
922        let err = uc.headers(Extensions::new()).await.unwrap_err();
923        assert!(err.is_transient(), "{err}");
924        assert!(err.to_string().contains("try again"), "{err:?}");
925        let source = err
926            .source()
927            .and_then(|e| e.downcast_ref::<reqwest::Error>());
928        assert!(
929            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
930            "{err:?}"
931        );
932
933        Ok(())
934    }
935
936    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
937    async fn token_provider_nonretryable_error() -> TestResult {
938        let (endpoint, _server) = start(
939            StatusCode::UNAUTHORIZED,
940            serde_json::to_value("epic fail".to_string())?,
941            Arc::new(Mutex::new(0)),
942        )
943        .await;
944        println!("endpoint = {endpoint}");
945
946        let authorized_user = serde_json::json!({
947            "client_id": "test-client-id",
948            "client_secret": "test-client-secret",
949            "refresh_token": "test-refresh-token",
950            "type": "authorized_user",
951            "token_uri": endpoint});
952
953        let uc = Builder::new(authorized_user).build()?;
954        let err = uc.headers(Extensions::new()).await.unwrap_err();
955        assert!(!err.is_transient(), "{err:?}");
956        assert!(err.to_string().contains("epic fail"), "{err:?}");
957        let source = err
958            .source()
959            .and_then(|e| e.downcast_ref::<reqwest::Error>());
960        assert!(
961            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
962            "{err:?}"
963        );
964
965        Ok(())
966    }
967
968    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
969    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
970        let (endpoint, _server) = start(
971            StatusCode::OK,
972            serde_json::to_value("bad json".to_string())?,
973            Arc::new(Mutex::new(0)),
974        )
975        .await;
976        println!("endpoint = {endpoint}");
977
978        let authorized_user = serde_json::json!({
979            "client_id": "test-client-id",
980            "client_secret": "test-client-secret",
981            "refresh_token": "test-refresh-token",
982            "type": "authorized_user",
983            "token_uri": endpoint
984        });
985
986        let uc = Builder::new(authorized_user).build()?;
987        let e = uc.headers(Extensions::new()).await.err().unwrap();
988        assert!(!e.is_transient(), "{e}");
989
990        Ok(())
991    }
992
993    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
994    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
995        let authorized_user = serde_json::json!({
996            "client_secret": "test-client-secret",
997            "refresh_token": "test-refresh-token",
998            "type": "authorized_user",
999        });
1000
1001        let e = Builder::new(authorized_user).build().unwrap_err();
1002        assert!(e.is_parsing(), "{e}");
1003
1004        Ok(())
1005    }
1006}