google_cloud_auth/credentials/
user_account.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [User Account] Credentials type.
16//!
17//! User accounts represent a developer, administrator, or any other person who
18//! interacts with Google APIs and services. User accounts are managed as
19//! [Google Accounts], either via [Google Workspace] or [Cloud Identity].
20//!
21//! This module provides [Credentials] derived from user account
22//! information, specifically utilizing an OAuth 2.0 refresh token.
23//!
24//! This module is designed for refresh tokens obtained via the standard
25//! [Authorization Code grant]. Acquiring the initial refresh token (e.g., through
26//! user consent) is outside the scope of this library.
27//! See [RFC 6749 Section 4.1] for flow details.
28//!
29//! The Google Cloud client libraries for Rust will typically find and use these
30//! credentials automatically if a credentials file exists in the
31//! standard ADC search paths. This file is often created by running:
32//! `gcloud auth application-default login`. You might instantiate these credentials
33//! directly using the [`Builder`] if you need to:
34//! * Load credentials from a non-standard location or source.
35//! * Override the OAuth 2.0 **scopes** being requested for the access token.
36//! * Override the **quota project ID** for billing and quota management.
37//! * Override the **token URI** used to fetch access tokens.
38//! * Customize the **retry behavior** when fetching access tokens.
39//!
40//! ## Example: Creating credentials from a JSON object
41//!
42//! ```
43//! # use google_cloud_auth::credentials::user_account::Builder;
44//! # use google_cloud_auth::credentials::Credentials;
45//! # use http::Extensions;
46//! # tokio_test::block_on(async {
47//! let authorized_user = serde_json::json!({
48//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com", // Replace with your actual Client ID
49//!     "client_secret": "YOUR_CLIENT_SECRET", // Replace with your actual Client Secret - LOAD SECURELY!
50//!     "refresh_token": "YOUR_REFRESH_TOKEN", // Replace with the user's refresh token - LOAD SECURELY!
51//!     "type": "authorized_user",
52//!     // "quota_project_id": "your-billing-project-id", // Optional: Set if needed
53//!     // "token_uri" : "test-token-uri", // Optional: Set if needed
54//! });
55//! let credentials: Credentials = Builder::new(authorized_user).build()?;
56//! let headers = credentials.headers(Extensions::new()).await?;
57//! println!("Headers: {headers:?}");
58//! # Ok::<(), anyhow::Error>(())
59//! # });
60//! ```
61//!
62//! ## Example: Creating credentials with custom retry behavior
63//!
64//! ```
65//! # use google_cloud_auth::credentials::user_account::Builder;
66//! # use google_cloud_auth::credentials::Credentials;
67//! # use http::Extensions;
68//! # use std::time::Duration;
69//! # tokio_test::block_on(async {
70//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
71//! use gax::exponential_backoff::ExponentialBackoff;
72//! let authorized_user = serde_json::json!({
73//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
74//!     "client_secret": "YOUR_CLIENT_SECRET",
75//!     "refresh_token": "YOUR_REFRESH_TOKEN",
76//!     "type": "authorized_user",
77//! });
78//! let backoff = ExponentialBackoff::default();
79//! let credentials: Credentials = Builder::new(authorized_user)
80//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
81//!     .with_backoff_policy(backoff)
82//!     .build()?;
83//! let headers = credentials.headers(Extensions::new()).await?;
84//! println!("Headers: {headers:?}");
85//! # Ok::<(), anyhow::Error>(())
86//! # });
87//! ```
88//!
89//! [Authorization Code grant]: https://tools.ietf.org/html/rfc6749#section-1.3.1
90//! [Cloud Identity]: https://cloud.google.com/identity
91//! [Google Accounts]: https://myaccount.google.com/
92//! [Google Workspace]: https://workspace.google.com/
93//! [RFC 6749 Section 4.1]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.1
94//! [User Account]: https://cloud.google.com/docs/authentication#user-accounts
95//! [Workforce Identity Federation]: https://cloud.google.com/iam/docs/workforce-identity-federation
96
97use crate::build_errors::Error as BuilderError;
98use crate::constants::OAUTH2_TOKEN_SERVER_URL;
99use crate::credentials::dynamic::CredentialsProvider;
100use crate::credentials::{CacheableResource, Credentials};
101use crate::errors::{self, CredentialsError};
102use crate::headers_util::build_cacheable_headers;
103use crate::retry::Builder as RetryTokenProviderBuilder;
104use crate::token::{CachedTokenProvider, Token, TokenProvider};
105use crate::token_cache::TokenCache;
106use crate::{BuildResult, Result};
107use gax::backoff_policy::BackoffPolicyArg;
108use gax::retry_policy::RetryPolicyArg;
109use gax::retry_throttler::RetryThrottlerArg;
110use http::header::CONTENT_TYPE;
111use http::{Extensions, HeaderMap, HeaderValue};
112use reqwest::{Client, Method};
113use serde_json::Value;
114use std::sync::Arc;
115use tokio::time::{Duration, Instant};
116
117/// A builder for constructing `user_account` [Credentials] instance.
118///
119/// # Example
120/// ```
121/// # use google_cloud_auth::credentials::user_account::Builder;
122/// # tokio_test::block_on(async {
123/// let authorized_user = serde_json::json!({ /* add details here */ });
124/// let credentials = Builder::new(authorized_user).build();
125/// })
126/// ```
127pub struct Builder {
128    authorized_user: Value,
129    scopes: Option<Vec<String>>,
130    quota_project_id: Option<String>,
131    token_uri: Option<String>,
132    retry_builder: RetryTokenProviderBuilder,
133}
134
135impl Builder {
136    /// Creates a new builder using `authorized_user` JSON value.
137    ///
138    /// The `authorized_user` JSON is typically generated when a user
139    /// authenticates using the [application-default login] process.
140    ///
141    /// [application-default login]: https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login
142    pub fn new(authorized_user: Value) -> Self {
143        Self {
144            authorized_user,
145            scopes: None,
146            quota_project_id: None,
147            token_uri: None,
148            retry_builder: RetryTokenProviderBuilder::default(),
149        }
150    }
151
152    /// Sets the URI for the token endpoint used to fetch access tokens.
153    ///
154    /// Any value provided here overrides a `token_uri` value from the input `authorized_user` JSON.
155    /// Defaults to `https://oauth2.googleapis.com/token` if not specified here or in the `authorized_user` JSON.
156    ///
157    /// # Example
158    /// ```
159    /// # use google_cloud_auth::credentials::user_account::Builder;
160    /// let authorized_user = serde_json::json!({ /* add details here */ });
161    /// let credentials = Builder::new(authorized_user)
162    ///     .with_token_uri("https://oauth2-FOOBAR.p.googleapis.com")
163    ///     .build();
164    /// ```
165    pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
166        self.token_uri = Some(token_uri.into());
167        self
168    }
169
170    /// Sets the [scopes] for these credentials.
171    ///
172    /// `scopes` define the *permissions being requested* for this specific access token
173    /// when interacting with a service. For example, `https://www.googleapis.com/auth/devstorage.read_write`.
174    /// IAM permissions, on the other hand, define the *underlying capabilities*
175    /// the user account possesses within a system. For example, `storage.buckets.delete`.
176    /// When a token generated with specific scopes is used, the request must be permitted
177    /// by both the user account's underlying IAM permissions and the scopes requested
178    /// for the token. Therefore, scopes act as an additional restriction on what the token
179    /// can be used for.
180    ///
181    /// # Example
182    /// ```
183    /// # use google_cloud_auth::credentials::user_account::Builder;
184    /// let authorized_user = serde_json::json!({ /* add details here */ });
185    /// let credentials = Builder::new(authorized_user)
186    ///     .with_scopes(["https://www.googleapis.com/auth/pubsub"])
187    ///     .build();
188    /// ```
189    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
190    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
191    where
192        I: IntoIterator<Item = S>,
193        S: Into<String>,
194    {
195        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
196        self
197    }
198
199    /// Sets the [quota project] for these credentials.
200    ///
201    /// In some services, you can use an account in
202    /// one project for authentication and authorization, and charge
203    /// the usage to a different project. This requires that the
204    /// user has `serviceusage.services.use` permissions on the quota project.
205    ///
206    /// Any value set here overrides a `quota_project_id` value from the
207    /// input `authorized_user` JSON.
208    ///
209    /// # Example
210    /// ```
211    /// # use google_cloud_auth::credentials::user_account::Builder;
212    /// let authorized_user = serde_json::json!("{ /* add details here */ }");
213    /// let credentials = Builder::new(authorized_user)
214    ///     .with_quota_project_id("my-project")
215    ///     .build();
216    /// ```
217    ///
218    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
219    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
220        self.quota_project_id = Some(quota_project_id.into());
221        self
222    }
223
224    /// Configure the retry policy for fetching tokens.
225    ///
226    /// The retry policy controls how to handle retries, and sets limits on
227    /// the number of attempts or the total time spent retrying.
228    ///
229    /// ```
230    /// # use google_cloud_auth::credentials::user_account::Builder;
231    /// # tokio_test::block_on(async {
232    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
233    /// let authorized_user = serde_json::json!({
234    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
235    ///     "client_secret": "YOUR_CLIENT_SECRET",
236    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
237    ///     "type": "authorized_user",
238    /// });
239    /// let credentials = Builder::new(authorized_user)
240    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
241    ///     .build();
242    /// # });
243    /// ```
244    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
245        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
246        self
247    }
248
249    /// Configure the retry backoff policy.
250    ///
251    /// The backoff policy controls how long to wait in between retry attempts.
252    ///
253    /// ```
254    /// # use google_cloud_auth::credentials::user_account::Builder;
255    /// # use std::time::Duration;
256    /// # tokio_test::block_on(async {
257    /// use gax::exponential_backoff::ExponentialBackoff;
258    /// let authorized_user = serde_json::json!({
259    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
260    ///     "client_secret": "YOUR_CLIENT_SECRET",
261    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
262    ///     "type": "authorized_user",
263    /// });
264    /// let credentials = Builder::new(authorized_user)
265    ///     .with_backoff_policy(ExponentialBackoff::default())
266    ///     .build();
267    /// # });
268    /// ```
269    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
270        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
271        self
272    }
273
274    /// Configure the retry throttler.
275    ///
276    /// Advanced applications may want to configure a retry throttler to
277    /// [Address Cascading Failures] and when [Handling Overload] conditions.
278    /// The authentication library throttles its retry loop, using a policy to
279    /// control the throttling algorithm. Use this method to fine tune or
280    /// customize the default retry throttler.
281    ///
282    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
283    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
284    ///
285    /// ```
286    /// # use google_cloud_auth::credentials::user_account::Builder;
287    /// # tokio_test::block_on(async {
288    /// use gax::retry_throttler::AdaptiveThrottler;
289    /// let authorized_user = serde_json::json!({
290    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
291    ///     "client_secret": "YOUR_CLIENT_SECRET",
292    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
293    ///     "type": "authorized_user",
294    /// });
295    /// let credentials = Builder::new(authorized_user)
296    ///     .with_retry_throttler(AdaptiveThrottler::default())
297    ///     .build();
298    /// # });
299    /// ```
300    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
301        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
302        self
303    }
304
305    /// Returns a [Credentials] instance with the configured settings.
306    ///
307    /// # Errors
308    ///
309    /// Returns a [CredentialsError] if the `authorized_user`
310    /// provided to [`Builder::new`] cannot be successfully deserialized into the
311    /// expected format. This typically happens if the JSON value is malformed or
312    /// missing required fields. For more information, on how to generate
313    /// `authorized_user` json, consult the relevant section in the
314    /// [application-default credentials] guide.
315    ///
316    /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
317    pub fn build(self) -> BuildResult<Credentials> {
318        let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
319            .map_err(BuilderError::parsing)?;
320        let endpoint = self
321            .token_uri
322            .or(authorized_user.token_uri)
323            .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
324        let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
325
326        let token_provider = UserTokenProvider {
327            client_id: authorized_user.client_id,
328            client_secret: authorized_user.client_secret,
329            refresh_token: authorized_user.refresh_token,
330            endpoint,
331            scopes: self.scopes.map(|scopes| scopes.join(" ")),
332            source: UserTokenSource::AccessToken,
333        };
334
335        let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
336
337        Ok(Credentials {
338            inner: Arc::new(UserCredentials {
339                token_provider,
340                quota_project_id,
341            }),
342        })
343    }
344}
345
346#[derive(PartialEq)]
347struct UserTokenProvider {
348    client_id: String,
349    client_secret: String,
350    refresh_token: String,
351    endpoint: String,
352    scopes: Option<String>,
353    source: UserTokenSource,
354}
355
/// Selects which token from the OAuth 2.0 refresh response is returned.
// `IdToken` is only constructed by the `idtoken` module, which is compiled
// only with the `google_cloud_unstable_id_token` cfg; silence the dead-code
// lint just for the configuration where it is genuinely unused.
#[cfg_attr(not(google_cloud_unstable_id_token), allow(dead_code))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum UserTokenSource {
    /// Return the response's `id_token`; an error if the response has none.
    IdToken,
    /// Return the response's `access_token`.
    AccessToken,
}
362
363impl std::fmt::Debug for UserTokenProvider {
364    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365        f.debug_struct("UserCredentials")
366            .field("client_id", &self.client_id)
367            .field("client_secret", &"[censored]")
368            .field("refresh_token", &"[censored]")
369            .field("endpoint", &self.endpoint)
370            .field("scopes", &self.scopes)
371            .finish()
372    }
373}
374
375#[async_trait::async_trait]
376impl TokenProvider for UserTokenProvider {
377    async fn token(&self) -> Result<Token> {
378        let client = Client::new();
379
380        // Make the request
381        let req = Oauth2RefreshRequest {
382            grant_type: RefreshGrantType::RefreshToken,
383            client_id: self.client_id.clone(),
384            client_secret: self.client_secret.clone(),
385            refresh_token: self.refresh_token.clone(),
386            scopes: self.scopes.clone(),
387        };
388        let header = HeaderValue::from_static("application/json");
389        let builder = client
390            .request(Method::POST, self.endpoint.as_str())
391            .header(CONTENT_TYPE, header)
392            .json(&req);
393        let resp = builder
394            .send()
395            .await
396            .map_err(|e| errors::from_http_error(e, MSG))?;
397
398        // Process the response
399        if !resp.status().is_success() {
400            let err = errors::from_http_response(resp, MSG).await;
401            return Err(err);
402        }
403        let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
404            let retryable = !e.is_decode();
405            CredentialsError::from_source(retryable, e)
406        })?;
407
408        let token = match self.source {
409            UserTokenSource::AccessToken => Ok(response.access_token),
410            UserTokenSource::IdToken => response
411                .id_token
412                .ok_or_else(|| CredentialsError::from_msg(false, MISSING_ID_TOKEN_MSG)),
413        }?;
414        let token = Token {
415            token,
416            token_type: response.token_type,
417            expires_at: response
418                .expires_in
419                .map(|d| Instant::now() + Duration::from_secs(d)),
420            metadata: None,
421        };
422        Ok(token)
423    }
424}
425
/// Context message attached to errors raised while refreshing a user token.
const MSG: &str = "failed to refresh user access token";
/// Error message used when an ID token was requested but the refresh response had none.
// The trailing `\` continues the literal and skips the next line's leading whitespace.
const MISSING_ID_TOKEN_MSG: &str = "UserCredentials can obtain an id token only when authenticated through \
    gcloud running `gcloud auth application-default login`";
429
430/// Data model for a UserCredentials
431///
432/// See: https://cloud.google.com/docs/authentication#user-accounts
433#[derive(Debug)]
434pub(crate) struct UserCredentials<T>
435where
436    T: CachedTokenProvider,
437{
438    token_provider: T,
439    quota_project_id: Option<String>,
440}
441
442#[async_trait::async_trait]
443impl<T> CredentialsProvider for UserCredentials<T>
444where
445    T: CachedTokenProvider,
446{
447    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
448        let token = self.token_provider.token(extensions).await?;
449        build_cacheable_headers(&token, &self.quota_project_id)
450    }
451}
452
453#[derive(Debug, PartialEq, serde::Deserialize)]
454pub(crate) struct AuthorizedUser {
455    #[serde(rename = "type")]
456    cred_type: String,
457    client_id: String,
458    client_secret: String,
459    refresh_token: String,
460    #[serde(skip_serializing_if = "Option::is_none")]
461    token_uri: Option<String>,
462    #[serde(skip_serializing_if = "Option::is_none")]
463    quota_project_id: Option<String>,
464}
465
466#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
467enum RefreshGrantType {
468    #[serde(rename = "refresh_token")]
469    RefreshToken,
470}
471
472#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
473struct Oauth2RefreshRequest {
474    grant_type: RefreshGrantType,
475    client_id: String,
476    client_secret: String,
477    refresh_token: String,
478    scopes: Option<String>,
479}
480
481#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
482struct Oauth2RefreshResponse {
483    access_token: String,
484    #[serde(skip_serializing_if = "Option::is_none")]
485    id_token: Option<String>,
486    #[serde(skip_serializing_if = "Option::is_none")]
487    scope: Option<String>,
488    #[serde(skip_serializing_if = "Option::is_none")]
489    expires_in: Option<u64>,
490    token_type: String,
491    #[serde(skip_serializing_if = "Option::is_none")]
492    refresh_token: Option<String>,
493}
494
#[cfg(google_cloud_unstable_id_token)]
pub mod idtoken {
    //! Credentials for authenticating with [ID tokens] from a [user account].
    //!
    //! This module provides a builder for [`IDTokenCredentials`] from
    //! authorized user credentials, which are typically obtained by running
    //! `gcloud auth application-default login`.
    //!
    //! These credentials are commonly used for [service to service authentication],
    //! for example when services are hosted in Cloud Run or mediated by Identity-Aware Proxy (IAP).
    //! ID tokens are only used to verify the identity of a principal. Google Cloud APIs do not use ID tokens
    //! for authorization, and therefore ID tokens cannot be used to access Google Cloud APIs.
    //!
    //! [ID tokens]: https://cloud.google.com/docs/authentication/token-types#identity-tokens
    //! [user account]: https://cloud.google.com/docs/authentication#user-accounts
    //! [Service to Service Authentication]: https://cloud.google.com/run/docs/authenticating/service-to-service

    use crate::build_errors::Error as BuilderError;
    use crate::constants::OAUTH2_TOKEN_SERVER_URL;
    use crate::{
        BuildResult, Result,
        credentials::{
            idtoken::{IDTokenCredentials, dynamic::IDTokenCredentialsProvider},
            user_account::AuthorizedUser,
        },
        token::TokenProvider,
    };
    use async_trait::async_trait;
    use serde_json::Value;
    use std::sync::Arc;

    /// Adapts a [`TokenProvider`] into an ID token credentials provider.
    ///
    /// Each call to `id_token` asks the wrapped provider for a token. No caching
    /// or retry layer is added here.
    #[derive(Debug)]
    struct UserAccountCredentials<T>
    where
        T: TokenProvider,
    {
        token_provider: T,
    }

    #[async_trait]
    impl<T> IDTokenCredentialsProvider for UserAccountCredentials<T>
    where
        T: TokenProvider,
    {
        /// Returns the string value of the token produced by the wrapped provider.
        async fn id_token(&self) -> Result<String> {
            self.token_provider.token().await.map(|token| token.token)
        }
    }

    /// A builder for [`IDTokenCredentials`] instances backed by user account credentials.
    pub struct Builder {
        /// The raw `authorized_user` JSON; it is only parsed when `build()` is called.
        authorized_user: Value,
        /// Optional token endpoint; takes precedence over the value in the JSON.
        token_uri: Option<String>,
    }

    impl Builder {
        /// Creates a new builder for `IDTokenCredentials` from a `serde_json::Value`
        /// representing the authorized user credentials.
        ///
        /// The `authorized_user` JSON is typically generated when a user
        /// authenticates using the [application-default login] process.
        ///
        /// [application-default login]: https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login
        pub fn new(authorized_user: Value) -> Self {
            Self {
                authorized_user,
                token_uri: None,
            }
        }

        /// Sets the URI for the token endpoint used to fetch tokens.
        ///
        /// Any value provided here overrides a `token_uri` value from the input `authorized_user` JSON.
        /// Defaults to `https://oauth2.googleapis.com/token` if not specified here or in the `authorized_user` JSON.
        pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
            self.token_uri = Some(token_uri.into());
            self
        }

        /// Parses the `authorized_user` JSON into a provider that returns ID tokens.
        fn build_token_provider(self) -> BuildResult<super::UserTokenProvider> {
            let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
                .map_err(BuilderError::parsing)?;
            // Only allocate the default endpoint when neither source provides one.
            let endpoint = self
                .token_uri
                .or(authorized_user.token_uri)
                .unwrap_or_else(|| OAUTH2_TOKEN_SERVER_URL.to_string());
            Ok(super::UserTokenProvider {
                client_id: authorized_user.client_id,
                client_secret: authorized_user.client_secret,
                refresh_token: authorized_user.refresh_token,
                endpoint,
                source: super::UserTokenSource::IdToken,
                scopes: None,
            })
        }

        /// Returns an [`IDTokenCredentials`] instance with the configured
        /// settings.
        ///
        /// # Errors
        ///
        /// Returns a build error if the `authorized_user` value
        /// provided to [`Builder::new`] cannot be successfully deserialized into the
        /// expected format. This typically happens if the JSON value is malformed or
        /// missing required fields. For more information on how to generate
        /// `authorized_user` json, consult the relevant section in the
        /// [application-default credentials] guide.
        ///
        /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
        pub fn build(self) -> BuildResult<IDTokenCredentials> {
            let creds = UserAccountCredentials {
                token_provider: self.build_token_provider()?,
            };
            Ok(IDTokenCredentials {
                inner: Arc::new(creds),
            })
        }
    }
}
613
614#[cfg(test)]
615mod tests {
616    use super::*;
617    use crate::credentials::tests::{
618        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
619        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
620        get_token_type_from_headers,
621    };
622    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
623    use crate::errors::CredentialsError;
624    use crate::token::tests::MockTokenProvider;
625    use http::StatusCode;
626    use http::header::AUTHORIZATION;
627    use httptest::cycle;
628    use httptest::matchers::{all_of, json_decoded, request};
629    use httptest::responders::{json_encoded, status_code};
630    use httptest::{Expectation, Server};
631
632    type TestResult = anyhow::Result<()>;
633
634    pub(crate) fn authorized_user_json(token_uri: String) -> Value {
635        serde_json::json!({
636            "client_id": "test-client-id",
637            "client_secret": "test-client-secret",
638            "refresh_token": "test-refresh-token",
639            "type": "authorized_user",
640            "token_uri": token_uri,
641        })
642    }
643
644    #[tokio::test]
645    async fn test_user_account_retries_on_transient_failures() -> TestResult {
646        let mut server = Server::run();
647        server.expect(
648            Expectation::matching(request::path("/token"))
649                .times(3)
650                .respond_with(status_code(503)),
651        );
652
653        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
654            .with_retry_policy(get_mock_auth_retry_policy(3))
655            .with_backoff_policy(get_mock_backoff_policy())
656            .with_retry_throttler(get_mock_retry_throttler())
657            .build()?;
658
659        let err = credentials.headers(Extensions::new()).await.unwrap_err();
660        assert!(!err.is_transient());
661        server.verify_and_clear();
662        Ok(())
663    }
664
665    #[tokio::test]
666    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
667        let mut server = Server::run();
668        server.expect(
669            Expectation::matching(request::path("/token"))
670                .times(1)
671                .respond_with(status_code(401)),
672        );
673
674        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
675            .with_retry_policy(get_mock_auth_retry_policy(1))
676            .with_backoff_policy(get_mock_backoff_policy())
677            .with_retry_throttler(get_mock_retry_throttler())
678            .build()?;
679
680        let err = credentials.headers(Extensions::new()).await.unwrap_err();
681        assert!(!err.is_transient());
682        server.verify_and_clear();
683        Ok(())
684    }
685
686    #[tokio::test]
687    async fn test_user_account_retries_for_success() -> TestResult {
688        let mut server = Server::run();
689        let response = Oauth2RefreshResponse {
690            access_token: "test-access-token".to_string(),
691            id_token: None,
692            expires_in: Some(3600),
693            refresh_token: Some("test-refresh-token".to_string()),
694            scope: Some("scope1 scope2".to_string()),
695            token_type: "test-token-type".to_string(),
696        };
697
698        server.expect(
699            Expectation::matching(request::path("/token"))
700                .times(3)
701                .respond_with(cycle![
702                    status_code(503).body("try-again"),
703                    status_code(503).body("try-again"),
704                    status_code(200)
705                        .append_header("Content-Type", "application/json")
706                        .body(serde_json::to_string(&response).unwrap()),
707                ]),
708        );
709
710        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
711            .with_retry_policy(get_mock_auth_retry_policy(3))
712            .with_backoff_policy(get_mock_backoff_policy())
713            .with_retry_throttler(get_mock_retry_throttler())
714            .build()?;
715
716        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
717        assert_eq!(token.unwrap(), "test-access-token");
718
719        server.verify_and_clear();
720        Ok(())
721    }
722
723    #[test]
724    fn debug_token_provider() {
725        let expected = UserTokenProvider {
726            client_id: "test-client-id".to_string(),
727            client_secret: "test-client-secret".to_string(),
728            refresh_token: "test-refresh-token".to_string(),
729            endpoint: OAUTH2_TOKEN_SERVER_URL.to_string(),
730            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
731            source: UserTokenSource::AccessToken,
732        };
733        let fmt = format!("{expected:?}");
734        assert!(fmt.contains("test-client-id"), "{fmt}");
735        assert!(!fmt.contains("test-client-secret"), "{fmt}");
736        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
737        assert!(fmt.contains(OAUTH2_TOKEN_SERVER_URL), "{fmt}");
738        assert!(
739            fmt.contains("https://www.googleapis.com/auth/pubsub"),
740            "{fmt}"
741        );
742    }
743
744    #[test]
745    fn authorized_user_full_from_json_success() {
746        let json = serde_json::json!({
747            "account": "",
748            "client_id": "test-client-id",
749            "client_secret": "test-client-secret",
750            "refresh_token": "test-refresh-token",
751            "type": "authorized_user",
752            "universe_domain": "googleapis.com",
753            "quota_project_id": "test-project",
754            "token_uri" : "test-token-uri",
755        });
756
757        let expected = AuthorizedUser {
758            cred_type: "authorized_user".to_string(),
759            client_id: "test-client-id".to_string(),
760            client_secret: "test-client-secret".to_string(),
761            refresh_token: "test-refresh-token".to_string(),
762            quota_project_id: Some("test-project".to_string()),
763            token_uri: Some("test-token-uri".to_string()),
764        };
765        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
766        assert_eq!(actual, expected);
767    }
768
769    #[test]
770    fn authorized_user_partial_from_json_success() {
771        let json = serde_json::json!({
772            "client_id": "test-client-id",
773            "client_secret": "test-client-secret",
774            "refresh_token": "test-refresh-token",
775            "type": "authorized_user",
776        });
777
778        let expected = AuthorizedUser {
779            cred_type: "authorized_user".to_string(),
780            client_id: "test-client-id".to_string(),
781            client_secret: "test-client-secret".to_string(),
782            refresh_token: "test-refresh-token".to_string(),
783            quota_project_id: None,
784            token_uri: None,
785        };
786        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
787        assert_eq!(actual, expected);
788    }
789
790    #[test]
791    fn authorized_user_from_json_parse_fail() {
792        let json_full = serde_json::json!({
793            "client_id": "test-client-id",
794            "client_secret": "test-client-secret",
795            "refresh_token": "test-refresh-token",
796            "type": "authorized_user",
797            "quota_project_id": "test-project"
798        });
799
800        for required_field in ["client_id", "client_secret", "refresh_token"] {
801            let mut json = json_full.clone();
802            // Remove a required field from the JSON
803            json[required_field].take();
804            serde_json::from_value::<AuthorizedUser>(json)
805                .err()
806                .unwrap();
807        }
808    }
809
810    #[tokio::test]
811    async fn default_universe_domain_success() {
812        let mock = TokenCache::new(MockTokenProvider::new());
813
814        let uc = UserCredentials {
815            token_provider: mock,
816            quota_project_id: None,
817        };
818        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
819    }
820
821    #[tokio::test]
822    async fn headers_success() -> TestResult {
823        let token = Token {
824            token: "test-token".to_string(),
825            token_type: "Bearer".to_string(),
826            expires_at: None,
827            metadata: None,
828        };
829
830        let mut mock = MockTokenProvider::new();
831        mock.expect_token().times(1).return_once(|| Ok(token));
832
833        let uc = UserCredentials {
834            token_provider: TokenCache::new(mock),
835            quota_project_id: None,
836        };
837
838        let mut extensions = Extensions::new();
839        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
840        let (headers, entity_tag) = match cached_headers {
841            CacheableResource::New { entity_tag, data } => (data, entity_tag),
842            CacheableResource::NotModified => unreachable!("expecting new headers"),
843        };
844        let token = headers.get(AUTHORIZATION).unwrap();
845
846        assert_eq!(headers.len(), 1, "{headers:?}");
847        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
848        assert!(token.is_sensitive());
849
850        extensions.insert(entity_tag);
851
852        let cached_headers = uc.headers(extensions).await?;
853
854        match cached_headers {
855            CacheableResource::New { .. } => unreachable!("expecting new headers"),
856            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
857        };
858        Ok(())
859    }
860
861    #[tokio::test]
862    async fn headers_failure() {
863        let mut mock = MockTokenProvider::new();
864        mock.expect_token()
865            .times(1)
866            .return_once(|| Err(errors::non_retryable_from_str("fail")));
867
868        let uc = UserCredentials {
869            token_provider: TokenCache::new(mock),
870            quota_project_id: None,
871        };
872        assert!(uc.headers(Extensions::new()).await.is_err());
873    }
874
875    #[tokio::test]
876    async fn headers_with_quota_project_success() -> TestResult {
877        let token = Token {
878            token: "test-token".to_string(),
879            token_type: "Bearer".to_string(),
880            expires_at: None,
881            metadata: None,
882        };
883
884        let mut mock = MockTokenProvider::new();
885        mock.expect_token().times(1).return_once(|| Ok(token));
886
887        let uc = UserCredentials {
888            token_provider: TokenCache::new(mock),
889            quota_project_id: Some("test-project".to_string()),
890        };
891
892        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
893        let token = headers.get(AUTHORIZATION).unwrap();
894        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
895
896        assert_eq!(headers.len(), 2, "{headers:?}");
897        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
898        assert!(token.is_sensitive());
899        assert_eq!(
900            quota_project_header,
901            HeaderValue::from_static("test-project")
902        );
903        assert!(!quota_project_header.is_sensitive());
904        Ok(())
905    }
906
907    #[test]
908    fn oauth2_request_serde() {
909        let request = Oauth2RefreshRequest {
910            grant_type: RefreshGrantType::RefreshToken,
911            client_id: "test-client-id".to_string(),
912            client_secret: "test-client-secret".to_string(),
913            refresh_token: "test-refresh-token".to_string(),
914            scopes: Some("scope1 scope2".to_string()),
915        };
916
917        let json = serde_json::to_value(&request).unwrap();
918        let expected = serde_json::json!({
919            "grant_type": "refresh_token",
920            "client_id": "test-client-id",
921            "client_secret": "test-client-secret",
922            "refresh_token": "test-refresh-token",
923            "scopes": "scope1 scope2",
924        });
925        assert_eq!(json, expected);
926        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
927        assert_eq!(request, roundtrip);
928    }
929
930    #[test]
931    fn oauth2_response_serde_full() {
932        let response = Oauth2RefreshResponse {
933            access_token: "test-access-token".to_string(),
934            id_token: None,
935            scope: Some("scope1 scope2".to_string()),
936            expires_in: Some(3600),
937            token_type: "test-token-type".to_string(),
938            refresh_token: Some("test-refresh-token".to_string()),
939        };
940
941        let json = serde_json::to_value(&response).unwrap();
942        let expected = serde_json::json!({
943            "access_token": "test-access-token",
944            "scope": "scope1 scope2",
945            "expires_in": 3600,
946            "token_type": "test-token-type",
947            "refresh_token": "test-refresh-token"
948        });
949        assert_eq!(json, expected);
950        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
951        assert_eq!(response, roundtrip);
952    }
953
954    #[test]
955    fn oauth2_response_serde_partial() {
956        let response = Oauth2RefreshResponse {
957            access_token: "test-access-token".to_string(),
958            id_token: None,
959            scope: None,
960            expires_in: None,
961            token_type: "test-token-type".to_string(),
962            refresh_token: None,
963        };
964
965        let json = serde_json::to_value(&response).unwrap();
966        let expected = serde_json::json!({
967            "access_token": "test-access-token",
968            "token_type": "test-token-type",
969        });
970        assert_eq!(json, expected);
971        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
972        assert_eq!(response, roundtrip);
973    }
974
975    pub(crate) fn check_request(
976        request: &Oauth2RefreshRequest,
977        expected_scopes: Option<String>,
978    ) -> bool {
979        request.client_id == "test-client-id"
980            && request.client_secret == "test-client-secret"
981            && request.refresh_token == "test-refresh-token"
982            && request.grant_type == RefreshGrantType::RefreshToken
983            && request.scopes == expected_scopes
984    }
985
986    #[tokio::test(start_paused = true)]
987    async fn token_provider_full() -> TestResult {
988        let server = Server::run();
989        let response = Oauth2RefreshResponse {
990            access_token: "test-access-token".to_string(),
991            id_token: None,
992            expires_in: Some(3600),
993            refresh_token: Some("test-refresh-token".to_string()),
994            scope: Some("scope1 scope2".to_string()),
995            token_type: "test-token-type".to_string(),
996        };
997        server.expect(
998            Expectation::matching(all_of![
999                request::path("/token"),
1000                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1001                    check_request(req, Some("scope1 scope2".to_string()))
1002                }))
1003            ])
1004            .respond_with(json_encoded(response)),
1005        );
1006
1007        let tp = UserTokenProvider {
1008            client_id: "test-client-id".to_string(),
1009            client_secret: "test-client-secret".to_string(),
1010            refresh_token: "test-refresh-token".to_string(),
1011            endpoint: server.url("/token").to_string(),
1012            scopes: Some("scope1 scope2".to_string()),
1013            source: UserTokenSource::AccessToken,
1014        };
1015        let now = Instant::now();
1016        let token = tp.token().await?;
1017        assert_eq!(token.token, "test-access-token");
1018        assert_eq!(token.token_type, "test-token-type");
1019        assert!(
1020            token
1021                .expires_at
1022                .is_some_and(|d| d == now + Duration::from_secs(3600)),
1023            "now: {:?}, expires_at: {:?}",
1024            now,
1025            token.expires_at
1026        );
1027
1028        Ok(())
1029    }
1030
1031    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1032    async fn credential_full_with_quota_project() -> TestResult {
1033        let server = Server::run();
1034        let response = Oauth2RefreshResponse {
1035            access_token: "test-access-token".to_string(),
1036            id_token: None,
1037            expires_in: Some(3600),
1038            refresh_token: Some("test-refresh-token".to_string()),
1039            scope: None,
1040            token_type: "test-token-type".to_string(),
1041        };
1042        server.expect(
1043            Expectation::matching(all_of![
1044                request::path("/token"),
1045                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1046                    check_request(req, None)
1047                }))
1048            ])
1049            .respond_with(json_encoded(response)),
1050        );
1051
1052        let authorized_user = serde_json::json!({
1053            "client_id": "test-client-id",
1054            "client_secret": "test-client-secret",
1055            "refresh_token": "test-refresh-token",
1056            "type": "authorized_user",
1057            "token_uri": server.url("/token").to_string(),
1058        });
1059        let cred = Builder::new(authorized_user)
1060            .with_quota_project_id("test-project")
1061            .build()?;
1062
1063        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
1064        let token = headers.get(AUTHORIZATION).unwrap();
1065        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
1066
1067        assert_eq!(headers.len(), 2, "{headers:?}");
1068        assert_eq!(
1069            token,
1070            HeaderValue::from_static("test-token-type test-access-token")
1071        );
1072        assert!(token.is_sensitive());
1073        assert_eq!(
1074            quota_project_header,
1075            HeaderValue::from_static("test-project")
1076        );
1077        assert!(!quota_project_header.is_sensitive());
1078
1079        Ok(())
1080    }
1081
1082    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1083    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
1084        let mut server = Server::run();
1085        let response = Oauth2RefreshResponse {
1086            access_token: "test-access-token".to_string(),
1087            id_token: None,
1088            expires_in: Some(3600),
1089            refresh_token: Some("test-refresh-token".to_string()),
1090            scope: Some("scope1 scope2".to_string()),
1091            token_type: "test-token-type".to_string(),
1092        };
1093        server.expect(
1094            Expectation::matching(all_of![
1095                request::path("/token"),
1096                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1097                    check_request(req, Some("scope1 scope2".to_string()))
1098                }))
1099            ])
1100            .times(1)
1101            .respond_with(json_encoded(response)),
1102        );
1103
1104        let json = serde_json::json!({
1105            "client_id": "test-client-id",
1106            "client_secret": "test-client-secret",
1107            "refresh_token": "test-refresh-token",
1108            "type": "authorized_user",
1109            "universe_domain": "googleapis.com",
1110            "quota_project_id": "test-project",
1111            "token_uri": server.url("/token").to_string(),
1112        });
1113
1114        let cred = Builder::new(json)
1115            .with_scopes(vec!["scope1", "scope2"])
1116            .build()?;
1117
1118        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
1119        assert_eq!(token.unwrap(), "test-access-token");
1120
1121        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
1122        assert_eq!(token.unwrap(), "test-access-token");
1123
1124        server.verify_and_clear();
1125
1126        Ok(())
1127    }
1128
1129    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1130    async fn credential_provider_partial() -> TestResult {
1131        let server = Server::run();
1132        let response = Oauth2RefreshResponse {
1133            access_token: "test-access-token".to_string(),
1134            id_token: None,
1135            expires_in: None,
1136            refresh_token: None,
1137            scope: None,
1138            token_type: "test-token-type".to_string(),
1139        };
1140        server.expect(
1141            Expectation::matching(all_of![
1142                request::path("/token"),
1143                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1144                    check_request(req, None)
1145                }))
1146            ])
1147            .respond_with(json_encoded(response)),
1148        );
1149
1150        let authorized_user = serde_json::json!({
1151            "client_id": "test-client-id",
1152            "client_secret": "test-client-secret",
1153            "refresh_token": "test-refresh-token",
1154            "type": "authorized_user",
1155            "token_uri": server.url("/token").to_string()
1156        });
1157
1158        let uc = Builder::new(authorized_user).build()?;
1159        let headers = uc.headers(Extensions::new()).await?;
1160        assert_eq!(
1161            get_token_from_headers(headers.clone()).unwrap(),
1162            "test-access-token"
1163        );
1164        assert_eq!(
1165            get_token_type_from_headers(headers).unwrap(),
1166            "test-token-type"
1167        );
1168
1169        Ok(())
1170    }
1171
1172    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1173    async fn credential_provider_with_token_uri() -> TestResult {
1174        let server = Server::run();
1175        let response = Oauth2RefreshResponse {
1176            access_token: "test-access-token".to_string(),
1177            id_token: None,
1178            expires_in: None,
1179            refresh_token: None,
1180            scope: None,
1181            token_type: "test-token-type".to_string(),
1182        };
1183        server.expect(
1184            Expectation::matching(all_of![
1185                request::path("/token"),
1186                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1187                    check_request(req, None)
1188                }))
1189            ])
1190            .respond_with(json_encoded(response)),
1191        );
1192
1193        let authorized_user = serde_json::json!({
1194            "client_id": "test-client-id",
1195            "client_secret": "test-client-secret",
1196            "refresh_token": "test-refresh-token",
1197            "type": "authorized_user",
1198            "token_uri": "test-endpoint"
1199        });
1200
1201        let uc = Builder::new(authorized_user)
1202            .with_token_uri(server.url("/token").to_string())
1203            .build()?;
1204        let headers = uc.headers(Extensions::new()).await?;
1205        assert_eq!(
1206            get_token_from_headers(headers.clone()).unwrap(),
1207            "test-access-token"
1208        );
1209        assert_eq!(
1210            get_token_type_from_headers(headers).unwrap(),
1211            "test-token-type"
1212        );
1213
1214        Ok(())
1215    }
1216
1217    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1218    async fn credential_provider_with_scopes() -> TestResult {
1219        let server = Server::run();
1220        let response = Oauth2RefreshResponse {
1221            access_token: "test-access-token".to_string(),
1222            id_token: None,
1223            expires_in: None,
1224            refresh_token: None,
1225            scope: Some("scope1 scope2".to_string()),
1226            token_type: "test-token-type".to_string(),
1227        };
1228        server.expect(
1229            Expectation::matching(all_of![
1230                request::path("/token"),
1231                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1232                    check_request(req, Some("scope1 scope2".to_string()))
1233                }))
1234            ])
1235            .respond_with(json_encoded(response)),
1236        );
1237
1238        let authorized_user = serde_json::json!({
1239            "client_id": "test-client-id",
1240            "client_secret": "test-client-secret",
1241            "refresh_token": "test-refresh-token",
1242            "type": "authorized_user",
1243            "token_uri": "test-endpoint"
1244        });
1245
1246        let uc = Builder::new(authorized_user)
1247            .with_token_uri(server.url("/token").to_string())
1248            .with_scopes(vec!["scope1", "scope2"])
1249            .build()?;
1250        let headers = uc.headers(Extensions::new()).await?;
1251        assert_eq!(
1252            get_token_from_headers(headers.clone()).unwrap(),
1253            "test-access-token"
1254        );
1255        assert_eq!(
1256            get_token_type_from_headers(headers).unwrap(),
1257            "test-token-type"
1258        );
1259
1260        Ok(())
1261    }
1262
1263    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1264    async fn credential_provider_retryable_error() -> TestResult {
1265        let server = Server::run();
1266        server
1267            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));
1268
1269        let authorized_user = serde_json::json!({
1270            "client_id": "test-client-id",
1271            "client_secret": "test-client-secret",
1272            "refresh_token": "test-refresh-token",
1273            "type": "authorized_user",
1274            "token_uri": server.url("/token").to_string()
1275        });
1276
1277        let uc = Builder::new(authorized_user).build()?;
1278        let err = uc.headers(Extensions::new()).await.unwrap_err();
1279        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1280        assert!(original_err.is_transient());
1281
1282        let source = find_source_error::<reqwest::Error>(&err);
1283        assert!(
1284            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1285            "{err:?}"
1286        );
1287
1288        Ok(())
1289    }
1290
1291    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1292    async fn token_provider_nonretryable_error() -> TestResult {
1293        let server = Server::run();
1294        server
1295            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));
1296
1297        let authorized_user = serde_json::json!({
1298            "client_id": "test-client-id",
1299            "client_secret": "test-client-secret",
1300            "refresh_token": "test-refresh-token",
1301            "type": "authorized_user",
1302            "token_uri": server.url("/token").to_string()
1303        });
1304
1305        let uc = Builder::new(authorized_user).build()?;
1306        let err = uc.headers(Extensions::new()).await.unwrap_err();
1307        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1308        assert!(!original_err.is_transient());
1309
1310        let source = find_source_error::<reqwest::Error>(&err);
1311        assert!(
1312            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1313            "{err:?}"
1314        );
1315
1316        Ok(())
1317    }
1318
1319    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1320    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
1321        let server = Server::run();
1322        server.expect(
1323            Expectation::matching(request::path("/token"))
1324                .respond_with(json_encoded("bad json".to_string())),
1325        );
1326
1327        let authorized_user = serde_json::json!({
1328            "client_id": "test-client-id",
1329            "client_secret": "test-client-secret",
1330            "refresh_token": "test-refresh-token",
1331            "type": "authorized_user",
1332            "token_uri": server.url("/token").to_string()
1333        });
1334
1335        let uc = Builder::new(authorized_user).build()?;
1336        let e = uc.headers(Extensions::new()).await.err().unwrap();
1337        assert!(!e.is_transient(), "{e}");
1338
1339        Ok(())
1340    }
1341
1342    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1343    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
1344        let authorized_user = serde_json::json!({
1345            "client_secret": "test-client-secret",
1346            "refresh_token": "test-refresh-token",
1347            "type": "authorized_user",
1348        });
1349
1350        let e = Builder::new(authorized_user).build().unwrap_err();
1351        assert!(e.is_parsing(), "{e}");
1352
1353        Ok(())
1354    }
1355}
1356
#[cfg(all(test, google_cloud_unstable_id_token))]
mod unstable_tests {
    use super::tests::*;
    use super::*;
    use crate::credentials::tests::find_source_error;
    use http::StatusCode;
    use httptest::matchers::{all_of, json_decoded, request};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};

    type TestResult = anyhow::Result<()>;

    /// Builds a refresh response with a one-hour expiry and a `Bearer` token
    /// type, optionally carrying an ID token.
    fn refresh_response(id_token: Option<String>) -> Oauth2RefreshResponse {
        Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "Bearer".to_string(),
        }
    }

    /// Registers an expectation on `server` for a scope-less refresh request
    /// at `/token` carrying the test credentials, answered with `response`.
    fn expect_refresh(server: &Server, response: Oauth2RefreshResponse) {
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );
    }

    /// The ID token from the refresh response is returned by `id_token()`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_success() -> TestResult {
        let server = Server::run();
        expect_refresh(
            &server,
            refresh_response(Some("test-id-token".to_string())),
        );

        let token_uri = server.url("/token").to_string();
        let creds = super::idtoken::Builder::new(authorized_user_json(token_uri)).build()?;
        let id_token = creds.id_token().await?;
        assert_eq!(id_token, "test-id-token");
        Ok(())
    }

    /// A refresh response without an ID token produces a non-transient error
    /// that mentions `MISSING_ID_TOKEN_MSG`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_missing_id_token_in_response() -> TestResult {
        let server = Server::run();
        expect_refresh(&server, refresh_response(None));

        let token_uri = server.url("/token").to_string();
        let creds = super::idtoken::Builder::new(authorized_user_json(token_uri)).build()?;
        let err = creds.id_token().await.unwrap_err();
        assert!(!err.is_transient());
        assert!(err.to_string().contains(MISSING_ID_TOKEN_MSG));
        Ok(())
    }

    /// The ID token builder reports a parsing error when `client_id` is absent.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_builder_malformed_authorized_json_nonretryable() -> TestResult {
        let missing_client_id = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let build_error = super::idtoken::Builder::new(missing_client_id)
            .build()
            .unwrap_err();
        assert!(build_error.is_parsing(), "{build_error}");
        Ok(())
    }

    /// A 503 from the token endpoint makes `id_token()` fail transiently.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_retryable_error() -> TestResult {
        let server = Server::run();
        let unavailable = status_code(503);
        server.expect(Expectation::matching(request::path("/token")).respond_with(unavailable));

        let token_uri = server.url("/token").to_string();
        let creds = super::idtoken::Builder::new(authorized_user_json(token_uri)).build()?;
        let err = creds.id_token().await.unwrap_err();
        assert!(err.is_transient());

        let http_error = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(http_error, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );
        Ok(())
    }

    /// A 401 from the token endpoint makes `id_token()` fail non-transiently.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_nonretryable_error() -> TestResult {
        let server = Server::run();
        let unauthorized = status_code(401);
        server.expect(Expectation::matching(request::path("/token")).respond_with(unauthorized));

        let token_uri = server.url("/token").to_string();
        let creds = super::idtoken::Builder::new(authorized_user_json(token_uri)).build()?;
        let err = creds.id_token().await.unwrap_err();
        assert!(!err.is_transient());

        let http_error = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(http_error, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );
        Ok(())
    }
}