google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
28//! You can use this access token to securely authenticate with Google Cloud,
29//! without having to download secrets or other credentials. The types in this
30//! module allow you to retrieve these access tokens, and can be used with
31//! the Google Cloud client libraries for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok::<(), anyhow::Error>(())
46//! # });
47//! ```
48//!
49//! ## Example: Creating credentials with custom retry behavior
50//!
51//! ```
52//! # use google_cloud_auth::credentials::mds::Builder;
53//! # use google_cloud_auth::credentials::Credentials;
54//! # use http::Extensions;
55//! # use std::time::Duration;
56//! # tokio_test::block_on(async {
57//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
58//! use gax::exponential_backoff::ExponentialBackoff;
59//! let backoff = ExponentialBackoff::default();
60//! let credentials: Credentials = Builder::default()
61//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
62//!     .with_backoff_policy(backoff)
63//!     .build()?;
64//! let headers = credentials.headers(Extensions::new()).await?;
65//! println!("Headers: {headers:?}");
66//! # Ok::<(), anyhow::Error>(())
67//! # });
68//! ```
69//!
70//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
71//! [Cloud Run]: https://cloud.google.com/run
72//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
73//! [gce-link]: https://cloud.google.com/products/compute
74//! [gke-link]: https://cloud.google.com/kubernetes-engine
75//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
76
77use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials, DEFAULT_UNIVERSE_DOMAIN};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
/// Scheme and host of the metadata service used when nothing overrides it.
const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path prefix, relative to the metadata root, for the default service account.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable that overrides the metadata service host.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";

/// Name of the header sent with every metadata service request.
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Value sent in the [METADATA_FLAVOR] header.
const METADATA_FLAVOR_VALUE: &str = "Google";

// TODO(#2235) - Improve this message by talking about retries when really running with MDS
/// Detailed error text that points users at local credentials setup.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
110
111#[derive(Debug)]
112struct MDSCredentials<T>
113where
114    T: CachedTokenProvider,
115{
116    quota_project_id: Option<String>,
117    universe_domain: Option<String>,
118    token_provider: T,
119}
120
121/// Creates [Credentials] instances backed by the [Metadata Service].
122///
123/// While the Google Cloud client libraries for Rust default to credentials
124/// backed by the metadata service, some applications may need to:
125/// * Customize the metadata service credentials in some way
126/// * Bypass the [Application Default Credentials] lookup and only
127///   use the metadata server credentials
128/// * Use the credentials directly outside the client libraries
129///
130/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
131/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
132#[derive(Debug, Default)]
133pub struct Builder {
134    endpoint: Option<String>,
135    quota_project_id: Option<String>,
136    scopes: Option<Vec<String>>,
137    universe_domain: Option<String>,
138    created_by_adc: bool,
139    retry_builder: RetryTokenProviderBuilder,
140}
141
142impl Builder {
143    /// Sets the endpoint for this credentials.
144    ///
145    /// A trailing slash is significant, so specify the base URL without a trailing  
146    /// slash. If not set, the credentials use `http://metadata.google.internal`.
147    ///
148    /// # Example
149    /// ```
150    /// # use google_cloud_auth::credentials::mds::Builder;
151    /// # tokio_test::block_on(async {
152    /// let credentials = Builder::default()
153    ///     .with_endpoint("https://metadata.google.foobar")
154    ///     .build();
155    /// # });
156    /// ```
157    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
158        self.endpoint = Some(endpoint.into());
159        self
160    }
161
162    /// Set the [quota project] for this credentials.
163    ///
164    /// In some services, you can use a service account in
165    /// one project for authentication and authorization, and charge
166    /// the usage to a different project. This may require that the
167    /// service account has `serviceusage.services.use` permissions on the quota project.
168    ///
169    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
170    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
171        self.quota_project_id = Some(quota_project_id.into());
172        self
173    }
174
175    /// Sets the universe domain for this credentials.
176    ///
177    /// Client libraries use `universe_domain` to determine
178    /// the API endpoints to use for making requests.
179    /// If not set, then credentials use `${service}.googleapis.com`,
180    /// otherwise they use `${service}.${universe_domain}.
181    pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
182        self.universe_domain = Some(universe_domain.into());
183        self
184    }
185
186    /// Sets the [scopes] for this credentials.
187    ///
188    /// Metadata server issues tokens based on the requested scopes.
189    /// If no scopes are specified, the credentials defaults to all
190    /// scopes configured for the [default service account] on the instance.
191    ///
192    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
193    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
194    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
195    where
196        I: IntoIterator<Item = S>,
197        S: Into<String>,
198    {
199        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
200        self
201    }
202
203    /// Configure the retry policy for fetching tokens.
204    ///
205    /// The retry policy controls how to handle retries, and sets limits on
206    /// the number of attempts or the total time spent retrying.
207    ///
208    /// ```
209    /// # use google_cloud_auth::credentials::mds::Builder;
210    /// # tokio_test::block_on(async {
211    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
212    /// let credentials = Builder::default()
213    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
214    ///     .build();
215    /// # });
216    /// ```
217    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
218        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
219        self
220    }
221
222    /// Configure the retry backoff policy.
223    ///
224    /// The backoff policy controls how long to wait in between retry attempts.
225    ///
226    /// ```
227    /// # use google_cloud_auth::credentials::mds::Builder;
228    /// # use std::time::Duration;
229    /// # tokio_test::block_on(async {
230    /// use gax::exponential_backoff::ExponentialBackoff;
231    /// let policy = ExponentialBackoff::default();
232    /// let credentials = Builder::default()
233    ///     .with_backoff_policy(policy)
234    ///     .build();
235    /// # });
236    /// ```
237    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
238        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
239        self
240    }
241
242    /// Configure the retry throttler.
243    ///
244    /// Advanced applications may want to configure a retry throttler to
245    /// [Address Cascading Failures] and when [Handling Overload] conditions.
246    /// The authentication library throttles its retry loop, using a policy to
247    /// control the throttling algorithm. Use this method to fine tune or
248    /// customize the default retry throttler.
249    ///
250    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
251    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
252    ///
253    /// ```
254    /// # use google_cloud_auth::credentials::mds::Builder;
255    /// # tokio_test::block_on(async {
256    /// use gax::retry_throttler::AdaptiveThrottler;
257    /// let credentials = Builder::default()
258    ///     .with_retry_throttler(AdaptiveThrottler::default())
259    ///     .build();
260    /// # });
261    /// ```
262    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
263        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
264        self
265    }
266
267    // This method is used to build mds credentials from ADC
268    pub(crate) fn from_adc() -> Self {
269        Self {
270            created_by_adc: true,
271            ..Default::default()
272        }
273    }
274
275    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
276        let final_endpoint: String;
277        let endpoint_overridden: bool;
278
279        // Determine the endpoint and whether it was overridden
280        if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
281            // Check GCE_METADATA_HOST environment variable first
282            final_endpoint = format!("http://{host_from_env}");
283            endpoint_overridden = true;
284        } else if let Some(builder_endpoint) = self.endpoint {
285            // Else, check if an endpoint was provided to the mds::Builder
286            final_endpoint = builder_endpoint;
287            endpoint_overridden = true;
288        } else {
289            // Else, use the default metadata root
290            final_endpoint = METADATA_ROOT.to_string();
291            endpoint_overridden = false;
292        };
293
294        let tp = MDSAccessTokenProvider::builder()
295            .endpoint(final_endpoint)
296            .maybe_scopes(self.scopes)
297            .endpoint_overridden(endpoint_overridden)
298            .created_by_adc(self.created_by_adc)
299            .build();
300        self.retry_builder.build(tp)
301    }
302
303    /// Returns a [Credentials] instance with the configured settings.
304    pub fn build(self) -> BuildResult<Credentials> {
305        let mdsc = MDSCredentials {
306            quota_project_id: self.quota_project_id.clone(),
307            universe_domain: self.universe_domain.clone(),
308            token_provider: TokenCache::new(self.build_token_provider()),
309        };
310        Ok(Credentials {
311            inner: Arc::new(mdsc),
312        })
313    }
314}
315
316#[async_trait::async_trait]
317impl<T> CredentialsProvider for MDSCredentials<T>
318where
319    T: CachedTokenProvider,
320{
321    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
322        let cached_token = self.token_provider.token(extensions).await?;
323        build_cacheable_headers(&cached_token, &self.quota_project_id)
324    }
325
326    async fn universe_domain(&self) -> Option<String> {
327        if self.universe_domain.is_some() {
328            return self.universe_domain.clone();
329        }
330        return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
331    }
332}
333
334#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
335struct MDSTokenResponse {
336    access_token: String,
337    #[serde(skip_serializing_if = "Option::is_none")]
338    expires_in: Option<u64>,
339    token_type: String,
340}
341
342#[derive(Debug, Clone, Default, Builder)]
343struct MDSAccessTokenProvider {
344    #[builder(into)]
345    scopes: Option<Vec<String>>,
346    #[builder(into)]
347    endpoint: String,
348    endpoint_overridden: bool,
349    created_by_adc: bool,
350}
351
352impl MDSAccessTokenProvider {
353    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
354    // environment variable is not set, we default to MDS credentials without checking if the code is really
355    // running in an environment with MDS. To help users who got to this state because of lack of credentials
356    // setup on their machines, we provide a detailed error message to them talking about local setup and other
357    // auth mechanisms available to them.
358    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
359    // error message because they deliberately wanted to use an MDS.
360    fn error_message(&self) -> &str {
361        if self.use_adc_message() {
362            MDS_NOT_FOUND_ERROR
363        } else {
364            "failed to fetch token"
365        }
366    }
367
368    fn use_adc_message(&self) -> bool {
369        self.created_by_adc && !self.endpoint_overridden
370    }
371}
372
373#[async_trait]
374impl TokenProvider for MDSAccessTokenProvider {
375    async fn token(&self) -> Result<Token> {
376        let client = Client::new();
377        let request = client
378            .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
379            .header(
380                METADATA_FLAVOR,
381                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
382            );
383        // Use the `scopes` option if set, otherwise let the MDS use the default
384        // scopes.
385        let scopes = self.scopes.as_ref().map(|v| v.join(","));
386        let request = scopes
387            .into_iter()
388            .fold(request, |r, s| r.query(&[("scopes", s)]));
389
390        // If the connection to MDS was not successful, it is useful to retry when really
391        // running on MDS environments and not useful if there is no MDS. We will mark the error
392        // as retryable and let the retry policy determine whether to retry or not. Whenever we
393        // define a default retry policy, we can skip retrying this case.
394        let response = request
395            .send()
396            .await
397            .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
398        // Process the response
399        if !response.status().is_success() {
400            let err = crate::errors::from_http_response(response, self.error_message()).await;
401            return Err(err);
402        }
403        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
404            // Decoding errors are not transient. Typically they indicate a badly
405            // configured MDS endpoint, or DNS redirecting the request to a random
406            // server, e.g., ISPs that redirect unknown services to HTTP.
407            CredentialsError::from_source(!e.is_decode(), e)
408        })?;
409        let token = Token {
410            token: response.access_token,
411            token_type: response.token_type,
412            expires_at: response
413                .expires_in
414                .map(|d| Instant::now() + Duration::from_secs(d)),
415            metadata: None,
416        };
417        Ok(token)
418    }
419}
420
421#[cfg(test)]
422mod tests {
423    use super::*;
424    use crate::credentials::QUOTA_PROJECT_KEY;
425    use crate::credentials::tests::{
426        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
427        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
428        get_token_type_from_headers,
429    };
430    use crate::errors;
431    use crate::errors::CredentialsError;
432    use crate::token::tests::MockTokenProvider;
433    use http::HeaderValue;
434    use http::header::AUTHORIZATION;
435    use httptest::cycle;
436    use httptest::matchers::{all_of, contains, request, url_decoded};
437    use httptest::responders::{json_encoded, status_code};
438    use httptest::{Expectation, Server};
439    use reqwest::StatusCode;
440    use scoped_env::ScopedEnv;
441    use serial_test::{parallel, serial};
442    use std::error::Error;
443    use test_case::test_case;
444    use url::Url;
445
446    type TestResult = anyhow::Result<()>;
447
448    #[tokio::test]
449    #[parallel]
450    async fn test_mds_retries_on_transient_failures() -> TestResult {
451        let mut server = Server::run();
452        server.expect(
453            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
454                .times(3)
455                .respond_with(status_code(503)),
456        );
457
458        let provider = Builder::default()
459            .with_endpoint(format!("http://{}", server.addr()))
460            .with_retry_policy(get_mock_auth_retry_policy(3))
461            .with_backoff_policy(get_mock_backoff_policy())
462            .with_retry_throttler(get_mock_retry_throttler())
463            .build_token_provider();
464
465        let err = provider.token().await.unwrap_err();
466        assert!(!err.is_transient());
467        server.verify_and_clear();
468        Ok(())
469    }
470
471    #[tokio::test]
472    #[parallel]
473    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
474        let mut server = Server::run();
475        server.expect(
476            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
477                .times(1)
478                .respond_with(status_code(401)),
479        );
480
481        let provider = Builder::default()
482            .with_endpoint(format!("http://{}", server.addr()))
483            .with_retry_policy(get_mock_auth_retry_policy(1))
484            .with_backoff_policy(get_mock_backoff_policy())
485            .with_retry_throttler(get_mock_retry_throttler())
486            .build_token_provider();
487
488        let err = provider.token().await.unwrap_err();
489        assert!(!err.is_transient());
490        server.verify_and_clear();
491        Ok(())
492    }
493
494    #[tokio::test]
495    #[parallel]
496    async fn test_mds_retries_for_success() -> TestResult {
497        let mut server = Server::run();
498        let response = MDSTokenResponse {
499            access_token: "test-access-token".to_string(),
500            expires_in: Some(3600),
501            token_type: "test-token-type".to_string(),
502        };
503
504        server.expect(
505            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
506                .times(3)
507                .respond_with(cycle![
508                    status_code(503).body("try-again"),
509                    status_code(503).body("try-again"),
510                    status_code(200)
511                        .append_header("Content-Type", "application/json")
512                        .body(serde_json::to_string(&response).unwrap()),
513                ]),
514        );
515
516        let provider = Builder::default()
517            .with_endpoint(format!("http://{}", server.addr()))
518            .with_retry_policy(get_mock_auth_retry_policy(3))
519            .with_backoff_policy(get_mock_backoff_policy())
520            .with_retry_throttler(get_mock_retry_throttler())
521            .build_token_provider();
522
523        let token = provider.token().await?;
524        assert_eq!(token.token, "test-access-token");
525
526        server.verify_and_clear();
527        Ok(())
528    }
529
530    #[test]
531    fn validate_default_endpoint_urls() {
532        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
533        assert!(default_endpoint_address.is_ok());
534
535        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
536        assert!(token_endpoint_address.is_ok());
537    }
538
539    #[tokio::test]
540    async fn headers_success() -> TestResult {
541        let token = Token {
542            token: "test-token".to_string(),
543            token_type: "Bearer".to_string(),
544            expires_at: None,
545            metadata: None,
546        };
547
548        let mut mock = MockTokenProvider::new();
549        mock.expect_token().times(1).return_once(|| Ok(token));
550
551        let mdsc = MDSCredentials {
552            quota_project_id: None,
553            universe_domain: None,
554            token_provider: TokenCache::new(mock),
555        };
556
557        let mut extensions = Extensions::new();
558        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
559        let (headers, entity_tag) = match cached_headers {
560            CacheableResource::New { entity_tag, data } => (data, entity_tag),
561            CacheableResource::NotModified => unreachable!("expecting new headers"),
562        };
563        let token = headers.get(AUTHORIZATION).unwrap();
564        assert_eq!(headers.len(), 1, "{headers:?}");
565        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
566        assert!(token.is_sensitive());
567
568        extensions.insert(entity_tag);
569
570        let cached_headers = mdsc.headers(extensions).await?;
571
572        match cached_headers {
573            CacheableResource::New { .. } => unreachable!("expecting new headers"),
574            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
575        };
576        Ok(())
577    }
578
579    #[tokio::test]
580    async fn headers_failure() {
581        let mut mock = MockTokenProvider::new();
582        mock.expect_token()
583            .times(1)
584            .return_once(|| Err(errors::non_retryable_from_str("fail")));
585
586        let mdsc = MDSCredentials {
587            quota_project_id: None,
588            universe_domain: None,
589            token_provider: TokenCache::new(mock),
590        };
591        assert!(mdsc.headers(Extensions::new()).await.is_err());
592    }
593
594    #[test]
595    fn error_message_with_adc() {
596        let provider = MDSAccessTokenProvider::builder()
597            .endpoint("http://127.0.0.1")
598            .created_by_adc(true)
599            .endpoint_overridden(false)
600            .build();
601
602        let want = MDS_NOT_FOUND_ERROR;
603        let got = provider.error_message();
604        assert!(got.contains(want), "{got}, {provider:?}");
605    }
606
607    #[test_case(false, false)]
608    #[test_case(false, true)]
609    #[test_case(true, true)]
610    fn error_message_without_adc(adc: bool, overridden: bool) {
611        let provider = MDSAccessTokenProvider::builder()
612            .endpoint("http://127.0.0.1")
613            .created_by_adc(adc)
614            .endpoint_overridden(overridden)
615            .build();
616
617        let not_want = MDS_NOT_FOUND_ERROR;
618        let got = provider.error_message();
619        assert!(!got.contains(not_want), "{got}, {provider:?}");
620    }
621
622    #[tokio::test]
623    #[serial]
624    async fn adc_no_mds() -> TestResult {
625        let err = Builder::from_adc()
626            .build_token_provider()
627            .token()
628            .await
629            .unwrap_err();
630
631        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
632        assert!(original_err.is_transient());
633        assert!(
634            original_err.to_string().contains("application-default"),
635            "display={err}, debug={err:?}"
636        );
637        let source = find_source_error::<reqwest::Error>(&err);
638        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
639
640        Ok(())
641    }
642
643    #[tokio::test]
644    #[serial]
645    async fn adc_overridden_mds() -> TestResult {
646        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
647
648        let err = Builder::from_adc()
649            .build_token_provider()
650            .token()
651            .await
652            .unwrap_err();
653
654        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
655
656        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
657        assert!(original_err.is_transient());
658        assert!(
659            !original_err.to_string().contains("application-default"),
660            "display={err}, debug={err:?}"
661        );
662        let source = find_source_error::<reqwest::Error>(&err);
663        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
664
665        Ok(())
666    }
667
668    #[tokio::test]
669    #[serial]
670    async fn builder_no_mds() -> TestResult {
671        let e = Builder::default()
672            .build_token_provider()
673            .token()
674            .await
675            .err()
676            .unwrap();
677
678        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
679        assert!(original_err.is_transient());
680        assert!(
681            !format!("{:?}", original_err.source()).contains("application-default"),
682            "{e:?}"
683        );
684
685        Ok(())
686    }
687
688    #[tokio::test]
689    #[serial]
690    async fn test_gce_metadata_host_env_var() -> TestResult {
691        let server = Server::run();
692        let scopes = ["scope1", "scope2"];
693        let response = MDSTokenResponse {
694            access_token: "test-access-token".to_string(),
695            expires_in: Some(3600),
696            token_type: "test-token-type".to_string(),
697        };
698        server.expect(
699            Expectation::matching(all_of![
700                request::path(format!("{MDS_DEFAULT_URI}/token")),
701                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
702            ])
703            .respond_with(json_encoded(response)),
704        );
705
706        let addr = server.addr().to_string();
707        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
708        let mdsc = Builder::default()
709            .with_scopes(["scope1", "scope2"])
710            .build()
711            .unwrap();
712        let headers = mdsc.headers(Extensions::new()).await.unwrap();
713        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
714
715        assert_eq!(
716            get_token_from_headers(headers).unwrap(),
717            "test-access-token"
718        );
719        Ok(())
720    }
721
722    #[tokio::test]
723    #[parallel]
724    async fn headers_success_with_quota_project() -> TestResult {
725        let server = Server::run();
726        let scopes = ["scope1", "scope2"];
727        let response = MDSTokenResponse {
728            access_token: "test-access-token".to_string(),
729            expires_in: Some(3600),
730            token_type: "test-token-type".to_string(),
731        };
732        server.expect(
733            Expectation::matching(all_of![
734                request::path(format!("{MDS_DEFAULT_URI}/token")),
735                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
736            ])
737            .respond_with(json_encoded(response)),
738        );
739
740        let mdsc = Builder::default()
741            .with_scopes(["scope1", "scope2"])
742            .with_endpoint(format!("http://{}", server.addr()))
743            .with_quota_project_id("test-project")
744            .build()?;
745
746        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
747        let token = headers.get(AUTHORIZATION).unwrap();
748        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
749
750        assert_eq!(headers.len(), 2, "{headers:?}");
751        assert_eq!(
752            token,
753            HeaderValue::from_static("test-token-type test-access-token")
754        );
755        assert!(token.is_sensitive());
756        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
757        assert!(!quota_project.is_sensitive());
758
759        Ok(())
760    }
761
762    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
763    #[parallel]
764    async fn token_caching() -> TestResult {
765        let mut server = Server::run();
766        let scopes = vec!["scope1".to_string()];
767        let response = MDSTokenResponse {
768            access_token: "test-access-token".to_string(),
769            expires_in: Some(3600),
770            token_type: "test-token-type".to_string(),
771        };
772        server.expect(
773            Expectation::matching(all_of![
774                request::path(format!("{MDS_DEFAULT_URI}/token")),
775                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
776            ])
777            .times(1)
778            .respond_with(json_encoded(response)),
779        );
780
781        let mdsc = Builder::default()
782            .with_scopes(scopes)
783            .with_endpoint(format!("http://{}", server.addr()))
784            .build()?;
785        let headers = mdsc.headers(Extensions::new()).await?;
786        assert_eq!(
787            get_token_from_headers(headers).unwrap(),
788            "test-access-token"
789        );
790        let headers = mdsc.headers(Extensions::new()).await?;
791        assert_eq!(
792            get_token_from_headers(headers).unwrap(),
793            "test-access-token"
794        );
795
796        // validate that the inner token provider is called only once
797        server.verify_and_clear();
798
799        Ok(())
800    }
801
802    #[tokio::test(start_paused = true)]
803    #[parallel]
804    async fn token_provider_full() -> TestResult {
805        let server = Server::run();
806        let scopes = vec!["scope1".to_string()];
807        let response = MDSTokenResponse {
808            access_token: "test-access-token".to_string(),
809            expires_in: Some(3600),
810            token_type: "test-token-type".to_string(),
811        };
812        server.expect(
813            Expectation::matching(all_of![
814                request::path(format!("{MDS_DEFAULT_URI}/token")),
815                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
816            ])
817            .respond_with(json_encoded(response)),
818        );
819
820        let token = Builder::default()
821            .with_endpoint(format!("http://{}", server.addr()))
822            .with_scopes(scopes)
823            .build_token_provider()
824            .token()
825            .await?;
826
827        let now = tokio::time::Instant::now();
828        assert_eq!(token.token, "test-access-token");
829        assert_eq!(token.token_type, "test-token-type");
830        assert!(
831            token
832                .expires_at
833                .is_some_and(|d| d >= now + Duration::from_secs(3600))
834        );
835
836        Ok(())
837    }
838
839    #[tokio::test(start_paused = true)]
840    #[parallel]
841    async fn token_provider_full_no_scopes() -> TestResult {
842        let server = Server::run();
843        let response = MDSTokenResponse {
844            access_token: "test-access-token".to_string(),
845            expires_in: Some(3600),
846            token_type: "test-token-type".to_string(),
847        };
848        server.expect(
849            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
850                .respond_with(json_encoded(response)),
851        );
852
853        let token = Builder::default()
854            .with_endpoint(format!("http://{}", server.addr()))
855            .build_token_provider()
856            .token()
857            .await?;
858
859        let now = Instant::now();
860        assert_eq!(token.token, "test-access-token");
861        assert_eq!(token.token_type, "test-token-type");
862        assert!(
863            token
864                .expires_at
865                .is_some_and(|d| d == now + Duration::from_secs(3600))
866        );
867
868        Ok(())
869    }
870
871    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
872    #[parallel]
873    async fn credential_provider_full() -> TestResult {
874        let server = Server::run();
875        let scopes = vec!["scope1".to_string()];
876        let response = MDSTokenResponse {
877            access_token: "test-access-token".to_string(),
878            expires_in: None,
879            token_type: "test-token-type".to_string(),
880        };
881        server.expect(
882            Expectation::matching(all_of![
883                request::path(format!("{MDS_DEFAULT_URI}/token")),
884                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
885            ])
886            .respond_with(json_encoded(response)),
887        );
888
889        let mdsc = Builder::default()
890            .with_endpoint(format!("http://{}", server.addr()))
891            .with_scopes(scopes)
892            .build()?;
893        let headers = mdsc.headers(Extensions::new()).await?;
894        assert_eq!(
895            get_token_from_headers(headers.clone()).unwrap(),
896            "test-access-token"
897        );
898        assert_eq!(
899            get_token_type_from_headers(headers).unwrap(),
900            "test-token-type"
901        );
902
903        Ok(())
904    }
905
906    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
907    #[parallel]
908    async fn credentials_headers_retryable_error() -> TestResult {
909        let server = Server::run();
910        let scopes = vec!["scope1".to_string()];
911        server.expect(
912            Expectation::matching(all_of![
913                request::path(format!("{MDS_DEFAULT_URI}/token")),
914                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
915            ])
916            .respond_with(status_code(503)),
917        );
918
919        let mdsc = Builder::default()
920            .with_endpoint(format!("http://{}", server.addr()))
921            .with_scopes(scopes)
922            .build()?;
923        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
924        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
925        assert!(original_err.is_transient());
926        let source = find_source_error::<reqwest::Error>(&err);
927        assert!(
928            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
929            "{err:?}"
930        );
931
932        Ok(())
933    }
934
935    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
936    #[parallel]
937    async fn credentials_headers_nonretryable_error() -> TestResult {
938        let server = Server::run();
939        let scopes = vec!["scope1".to_string()];
940        server.expect(
941            Expectation::matching(all_of![
942                request::path(format!("{MDS_DEFAULT_URI}/token")),
943                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
944            ])
945            .respond_with(status_code(401)),
946        );
947
948        let mdsc = Builder::default()
949            .with_endpoint(format!("http://{}", server.addr()))
950            .with_scopes(scopes)
951            .build()?;
952
953        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
954        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
955        assert!(!original_err.is_transient());
956        let source = find_source_error::<reqwest::Error>(&err);
957        assert!(
958            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
959            "{err:?}"
960        );
961
962        Ok(())
963    }
964
965    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
966    #[parallel]
967    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
968        let server = Server::run();
969        let scopes = vec!["scope1".to_string()];
970        server.expect(
971            Expectation::matching(all_of![
972                request::path(format!("{MDS_DEFAULT_URI}/token")),
973                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
974            ])
975            .respond_with(json_encoded("bad json")),
976        );
977
978        let mdsc = Builder::default()
979            .with_endpoint(format!("http://{}", server.addr()))
980            .with_scopes(scopes)
981            .build()?;
982
983        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
984        assert!(!e.is_transient());
985
986        Ok(())
987    }
988
989    #[tokio::test]
990    async fn get_default_universe_domain_success() -> TestResult {
991        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
992        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
993        Ok(())
994    }
995
996    #[tokio::test]
997    async fn get_custom_universe_domain_success() -> TestResult {
998        let universe_domain = "test-universe";
999        let universe_domain_response = Builder::default()
1000            .with_universe_domain(universe_domain)
1001            .build()?
1002            .universe_domain()
1003            .await
1004            .unwrap();
1005        assert_eq!(universe_domain_response, universe_domain);
1006
1007        Ok(())
1008    }
1009}