google_cloud_auth/credentials/
mds.rs

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! [Metadata Service] Credentials type.
//!
//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
//! This is a service local to the VM (or pod) which, as the name implies, provides
//! metadata about the VM. The service also provides access tokens associated
//! with the [default service account] for the corresponding VM.
//!
//! The default hostname of the metadata service is `metadata.google.internal`.
//! If you would like to use a different hostname, set the `GCE_METADATA_HOST`
//! environment variable. This environment variable takes precedence over any
//! endpoint configured with [Builder::with_endpoint].
//!
//! You can use these access tokens to securely authenticate with Google Cloud,
//! without having to download secrets or other credentials. The types in this
//! module allow you to retrieve these access tokens, and can be used with
//! the Google Cloud client libraries for Rust.
//!
//! ## Example: Creating credentials with a custom quota project
//!
//! ```
//! # use google_cloud_auth::credentials::mds::Builder;
//! # use google_cloud_auth::credentials::Credentials;
//! # use http::Extensions;
//! # tokio_test::block_on(async {
//! let credentials: Credentials = Builder::default()
//!     .with_quota_project_id("my-quota-project")
//!     .build()?;
//! let headers = credentials.headers(Extensions::new()).await?;
//! println!("Headers: {headers:?}");
//! # Ok::<(), anyhow::Error>(())
//! # });
//! ```
//!
//! ## Example: Creating credentials with custom retry behavior
//!
//! ```
//! # use google_cloud_auth::credentials::mds::Builder;
//! # use google_cloud_auth::credentials::Credentials;
//! # use http::Extensions;
//! # use std::time::Duration;
//! # tokio_test::block_on(async {
//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
//! use gax::exponential_backoff::ExponentialBackoff;
//! let backoff = ExponentialBackoff::default();
//! let credentials: Credentials = Builder::default()
//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
//!     .with_backoff_policy(backoff)
//!     .build()?;
//! let headers = credentials.headers(Extensions::new()).await?;
//! println!("Headers: {headers:?}");
//! # Ok::<(), anyhow::Error>(())
//! # });
//! ```
//!
//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
//! [Cloud Run]: https://cloud.google.com/run
//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
//! [gce-link]: https://cloud.google.com/products/compute
//! [gke-link]: https://cloud.google.com/kubernetes-engine
//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview

use crate::credentials::dynamic::CredentialsProvider;
use crate::credentials::{CacheableResource, Credentials};
use crate::errors::CredentialsError;
use crate::headers_util::build_cacheable_headers;
use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
use crate::token::{CachedTokenProvider, Token, TokenProvider};
use crate::token_cache::TokenCache;
use crate::{BuildResult, Result};
use async_trait::async_trait;
use bon::Builder;
use gax::backoff_policy::BackoffPolicyArg;
use gax::retry_policy::RetryPolicyArg;
use gax::retry_throttler::RetryThrottlerArg;
use http::{Extensions, HeaderMap, HeaderValue};
use reqwest::Client;
use std::default::Default;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;

/// Value sent in the [`METADATA_FLAVOR`] header on every metadata server request.
pub(crate) const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header that carries [`METADATA_FLAVOR_VALUE`].
pub(crate) const METADATA_FLAVOR: &str = "metadata-flavor";
/// Base URL used when no endpoint override is configured.
pub(crate) const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path of the default service account; `/token` is appended to fetch tokens.
pub(crate) const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable whose value, when set, replaces the metadata server host.
pub(crate) const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";

// Detailed troubleshooting message used when credentials created through ADC
// cannot fetch a token from the default (non-overridden) metadata endpoint.
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);

111#[derive(Debug)]
112struct MDSCredentials<T>
113where
114    T: CachedTokenProvider,
115{
116    quota_project_id: Option<String>,
117    token_provider: T,
118}
119
120/// Creates [Credentials] instances backed by the [Metadata Service].
121///
122/// While the Google Cloud client libraries for Rust default to credentials
123/// backed by the metadata service, some applications may need to:
124/// * Customize the metadata service credentials in some way
125/// * Bypass the [Application Default Credentials] lookup and only
126///   use the metadata server credentials
127/// * Use the credentials directly outside the client libraries
128///
129/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
130/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
131#[derive(Debug, Default)]
132pub struct Builder {
133    endpoint: Option<String>,
134    quota_project_id: Option<String>,
135    scopes: Option<Vec<String>>,
136    created_by_adc: bool,
137    retry_builder: RetryTokenProviderBuilder,
138}
139
140impl Builder {
141    /// Sets the endpoint for this credentials.
142    ///
143    /// A trailing slash is significant, so specify the base URL without a trailing  
144    /// slash. If not set, the credentials use `http://metadata.google.internal`.
145    ///
146    /// # Example
147    /// ```
148    /// # use google_cloud_auth::credentials::mds::Builder;
149    /// # tokio_test::block_on(async {
150    /// let credentials = Builder::default()
151    ///     .with_endpoint("https://metadata.google.foobar")
152    ///     .build();
153    /// # });
154    /// ```
155    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
156        self.endpoint = Some(endpoint.into());
157        self
158    }
159
160    /// Set the [quota project] for this credentials.
161    ///
162    /// In some services, you can use a service account in
163    /// one project for authentication and authorization, and charge
164    /// the usage to a different project. This may require that the
165    /// service account has `serviceusage.services.use` permissions on the quota project.
166    ///
167    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
168    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
169        self.quota_project_id = Some(quota_project_id.into());
170        self
171    }
172
173    /// Sets the [scopes] for this credentials.
174    ///
175    /// Metadata server issues tokens based on the requested scopes.
176    /// If no scopes are specified, the credentials defaults to all
177    /// scopes configured for the [default service account] on the instance.
178    ///
179    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
180    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
181    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
182    where
183        I: IntoIterator<Item = S>,
184        S: Into<String>,
185    {
186        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
187        self
188    }
189
190    /// Configure the retry policy for fetching tokens.
191    ///
192    /// The retry policy controls how to handle retries, and sets limits on
193    /// the number of attempts or the total time spent retrying.
194    ///
195    /// ```
196    /// # use google_cloud_auth::credentials::mds::Builder;
197    /// # tokio_test::block_on(async {
198    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
199    /// let credentials = Builder::default()
200    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
201    ///     .build();
202    /// # });
203    /// ```
204    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
205        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
206        self
207    }
208
209    /// Configure the retry backoff policy.
210    ///
211    /// The backoff policy controls how long to wait in between retry attempts.
212    ///
213    /// ```
214    /// # use google_cloud_auth::credentials::mds::Builder;
215    /// # use std::time::Duration;
216    /// # tokio_test::block_on(async {
217    /// use gax::exponential_backoff::ExponentialBackoff;
218    /// let policy = ExponentialBackoff::default();
219    /// let credentials = Builder::default()
220    ///     .with_backoff_policy(policy)
221    ///     .build();
222    /// # });
223    /// ```
224    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
225        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
226        self
227    }
228
229    /// Configure the retry throttler.
230    ///
231    /// Advanced applications may want to configure a retry throttler to
232    /// [Address Cascading Failures] and when [Handling Overload] conditions.
233    /// The authentication library throttles its retry loop, using a policy to
234    /// control the throttling algorithm. Use this method to fine tune or
235    /// customize the default retry throttler.
236    ///
237    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
238    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
239    ///
240    /// ```
241    /// # use google_cloud_auth::credentials::mds::Builder;
242    /// # tokio_test::block_on(async {
243    /// use gax::retry_throttler::AdaptiveThrottler;
244    /// let credentials = Builder::default()
245    ///     .with_retry_throttler(AdaptiveThrottler::default())
246    ///     .build();
247    /// # });
248    /// ```
249    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
250        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
251        self
252    }
253
254    // This method is used to build mds credentials from ADC
255    pub(crate) fn from_adc() -> Self {
256        Self {
257            created_by_adc: true,
258            ..Default::default()
259        }
260    }
261
262    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
263        let final_endpoint: String;
264        let endpoint_overridden: bool;
265
266        // Determine the endpoint and whether it was overridden
267        if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
268            // Check GCE_METADATA_HOST environment variable first
269            final_endpoint = format!("http://{host_from_env}");
270            endpoint_overridden = true;
271        } else if let Some(builder_endpoint) = self.endpoint {
272            // Else, check if an endpoint was provided to the mds::Builder
273            final_endpoint = builder_endpoint;
274            endpoint_overridden = true;
275        } else {
276            // Else, use the default metadata root
277            final_endpoint = METADATA_ROOT.to_string();
278            endpoint_overridden = false;
279        };
280
281        let tp = MDSAccessTokenProvider::builder()
282            .endpoint(final_endpoint)
283            .maybe_scopes(self.scopes)
284            .endpoint_overridden(endpoint_overridden)
285            .created_by_adc(self.created_by_adc)
286            .build();
287        self.retry_builder.build(tp)
288    }
289
290    /// Returns a [Credentials] instance with the configured settings.
291    pub fn build(self) -> BuildResult<Credentials> {
292        let mdsc = MDSCredentials {
293            quota_project_id: self.quota_project_id.clone(),
294            token_provider: TokenCache::new(self.build_token_provider()),
295        };
296        Ok(Credentials {
297            inner: Arc::new(mdsc),
298        })
299    }
300}
301
302#[async_trait::async_trait]
303impl<T> CredentialsProvider for MDSCredentials<T>
304where
305    T: CachedTokenProvider,
306{
307    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
308        let cached_token = self.token_provider.token(extensions).await?;
309        build_cacheable_headers(&cached_token, &self.quota_project_id)
310    }
311}
312
313#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
314struct MDSTokenResponse {
315    access_token: String,
316    #[serde(skip_serializing_if = "Option::is_none")]
317    expires_in: Option<u64>,
318    token_type: String,
319}
320
321#[derive(Debug, Clone, Default, Builder)]
322struct MDSAccessTokenProvider {
323    #[builder(into)]
324    scopes: Option<Vec<String>>,
325    #[builder(into)]
326    endpoint: String,
327    endpoint_overridden: bool,
328    created_by_adc: bool,
329}
330
331impl MDSAccessTokenProvider {
332    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
333    // environment variable is not set, we default to MDS credentials without checking if the code is really
334    // running in an environment with MDS. To help users who got to this state because of lack of credentials
335    // setup on their machines, we provide a detailed error message to them talking about local setup and other
336    // auth mechanisms available to them.
337    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
338    // error message because they deliberately wanted to use an MDS.
339    fn error_message(&self) -> &str {
340        if self.use_adc_message() {
341            MDS_NOT_FOUND_ERROR
342        } else {
343            "failed to fetch token"
344        }
345    }
346
347    fn use_adc_message(&self) -> bool {
348        self.created_by_adc && !self.endpoint_overridden
349    }
350}
351
352#[async_trait]
353impl TokenProvider for MDSAccessTokenProvider {
354    async fn token(&self) -> Result<Token> {
355        let client = Client::new();
356        let request = client
357            .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
358            .header(
359                METADATA_FLAVOR,
360                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
361            );
362        // Use the `scopes` option if set, otherwise let the MDS use the default
363        // scopes.
364        let scopes = self.scopes.as_ref().map(|v| v.join(","));
365        let request = scopes
366            .into_iter()
367            .fold(request, |r, s| r.query(&[("scopes", s)]));
368
369        // If the connection to MDS was not successful, it is useful to retry when really
370        // running on MDS environments and not useful if there is no MDS. We will mark the error
371        // as retryable and let the retry policy determine whether to retry or not. Whenever we
372        // define a default retry policy, we can skip retrying this case.
373        let response = request
374            .send()
375            .await
376            .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
377        // Process the response
378        if !response.status().is_success() {
379            let err = crate::errors::from_http_response(response, self.error_message()).await;
380            return Err(err);
381        }
382        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
383            // Decoding errors are not transient. Typically they indicate a badly
384            // configured MDS endpoint, or DNS redirecting the request to a random
385            // server, e.g., ISPs that redirect unknown services to HTTP.
386            CredentialsError::from_source(!e.is_decode(), e)
387        })?;
388        let token = Token {
389            token: response.access_token,
390            token_type: response.token_type,
391            expires_at: response
392                .expires_in
393                .map(|d| Instant::now() + Duration::from_secs(d)),
394            metadata: None,
395        };
396        Ok(token)
397    }
398}
399
400#[cfg(test)]
401mod tests {
402    use super::*;
403    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
404    use crate::credentials::QUOTA_PROJECT_KEY;
405    use crate::credentials::tests::{
406        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
407        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
408        get_token_type_from_headers,
409    };
410    use crate::errors;
411    use crate::errors::CredentialsError;
412    use crate::token::tests::MockTokenProvider;
413    use http::HeaderValue;
414    use http::header::AUTHORIZATION;
415    use httptest::cycle;
416    use httptest::matchers::{all_of, contains, request, url_decoded};
417    use httptest::responders::{json_encoded, status_code};
418    use httptest::{Expectation, Server};
419    use reqwest::StatusCode;
420    use scoped_env::ScopedEnv;
421    use serial_test::{parallel, serial};
422    use std::error::Error;
423    use test_case::test_case;
424    use url::Url;
425
426    type TestResult = anyhow::Result<()>;
427
428    #[tokio::test]
429    #[parallel]
430    async fn test_mds_retries_on_transient_failures() -> TestResult {
431        let mut server = Server::run();
432        server.expect(
433            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
434                .times(3)
435                .respond_with(status_code(503)),
436        );
437
438        let provider = Builder::default()
439            .with_endpoint(format!("http://{}", server.addr()))
440            .with_retry_policy(get_mock_auth_retry_policy(3))
441            .with_backoff_policy(get_mock_backoff_policy())
442            .with_retry_throttler(get_mock_retry_throttler())
443            .build_token_provider();
444
445        let err = provider.token().await.unwrap_err();
446        assert!(!err.is_transient());
447        server.verify_and_clear();
448        Ok(())
449    }
450
451    #[tokio::test]
452    #[parallel]
453    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
454        let mut server = Server::run();
455        server.expect(
456            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
457                .times(1)
458                .respond_with(status_code(401)),
459        );
460
461        let provider = Builder::default()
462            .with_endpoint(format!("http://{}", server.addr()))
463            .with_retry_policy(get_mock_auth_retry_policy(1))
464            .with_backoff_policy(get_mock_backoff_policy())
465            .with_retry_throttler(get_mock_retry_throttler())
466            .build_token_provider();
467
468        let err = provider.token().await.unwrap_err();
469        assert!(!err.is_transient());
470        server.verify_and_clear();
471        Ok(())
472    }
473
474    #[tokio::test]
475    #[parallel]
476    async fn test_mds_retries_for_success() -> TestResult {
477        let mut server = Server::run();
478        let response = MDSTokenResponse {
479            access_token: "test-access-token".to_string(),
480            expires_in: Some(3600),
481            token_type: "test-token-type".to_string(),
482        };
483
484        server.expect(
485            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
486                .times(3)
487                .respond_with(cycle![
488                    status_code(503).body("try-again"),
489                    status_code(503).body("try-again"),
490                    status_code(200)
491                        .append_header("Content-Type", "application/json")
492                        .body(serde_json::to_string(&response).unwrap()),
493                ]),
494        );
495
496        let provider = Builder::default()
497            .with_endpoint(format!("http://{}", server.addr()))
498            .with_retry_policy(get_mock_auth_retry_policy(3))
499            .with_backoff_policy(get_mock_backoff_policy())
500            .with_retry_throttler(get_mock_retry_throttler())
501            .build_token_provider();
502
503        let token = provider.token().await?;
504        assert_eq!(token.token, "test-access-token");
505
506        server.verify_and_clear();
507        Ok(())
508    }
509
510    #[test]
511    fn validate_default_endpoint_urls() {
512        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
513        assert!(default_endpoint_address.is_ok());
514
515        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
516        assert!(token_endpoint_address.is_ok());
517    }
518
519    #[tokio::test]
520    async fn headers_success() -> TestResult {
521        let token = Token {
522            token: "test-token".to_string(),
523            token_type: "Bearer".to_string(),
524            expires_at: None,
525            metadata: None,
526        };
527
528        let mut mock = MockTokenProvider::new();
529        mock.expect_token().times(1).return_once(|| Ok(token));
530
531        let mdsc = MDSCredentials {
532            quota_project_id: None,
533            token_provider: TokenCache::new(mock),
534        };
535
536        let mut extensions = Extensions::new();
537        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
538        let (headers, entity_tag) = match cached_headers {
539            CacheableResource::New { entity_tag, data } => (data, entity_tag),
540            CacheableResource::NotModified => unreachable!("expecting new headers"),
541        };
542        let token = headers.get(AUTHORIZATION).unwrap();
543        assert_eq!(headers.len(), 1, "{headers:?}");
544        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
545        assert!(token.is_sensitive());
546
547        extensions.insert(entity_tag);
548
549        let cached_headers = mdsc.headers(extensions).await?;
550
551        match cached_headers {
552            CacheableResource::New { .. } => unreachable!("expecting new headers"),
553            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
554        };
555        Ok(())
556    }
557
558    #[tokio::test]
559    async fn headers_failure() {
560        let mut mock = MockTokenProvider::new();
561        mock.expect_token()
562            .times(1)
563            .return_once(|| Err(errors::non_retryable_from_str("fail")));
564
565        let mdsc = MDSCredentials {
566            quota_project_id: None,
567            token_provider: TokenCache::new(mock),
568        };
569        assert!(mdsc.headers(Extensions::new()).await.is_err());
570    }
571
572    #[test]
573    fn error_message_with_adc() {
574        let provider = MDSAccessTokenProvider::builder()
575            .endpoint("http://127.0.0.1")
576            .created_by_adc(true)
577            .endpoint_overridden(false)
578            .build();
579
580        let want = MDS_NOT_FOUND_ERROR;
581        let got = provider.error_message();
582        assert!(got.contains(want), "{got}, {provider:?}");
583    }
584
585    #[test_case(false, false)]
586    #[test_case(false, true)]
587    #[test_case(true, true)]
588    fn error_message_without_adc(adc: bool, overridden: bool) {
589        let provider = MDSAccessTokenProvider::builder()
590            .endpoint("http://127.0.0.1")
591            .created_by_adc(adc)
592            .endpoint_overridden(overridden)
593            .build();
594
595        let not_want = MDS_NOT_FOUND_ERROR;
596        let got = provider.error_message();
597        assert!(!got.contains(not_want), "{got}, {provider:?}");
598    }
599
600    #[tokio::test]
601    #[serial]
602    async fn adc_no_mds() -> TestResult {
603        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
604            // The environment has an MDS, skip the test.
605            return Ok(());
606        };
607
608        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
609        assert!(
610            original_err.to_string().contains("application-default"),
611            "display={err}, debug={err:?}"
612        );
613
614        Ok(())
615    }
616
617    #[tokio::test]
618    #[serial]
619    async fn adc_overridden_mds() -> TestResult {
620        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
621
622        let err = Builder::from_adc()
623            .build_token_provider()
624            .token()
625            .await
626            .unwrap_err();
627
628        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
629
630        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
631        assert!(original_err.is_transient());
632        assert!(
633            !original_err.to_string().contains("application-default"),
634            "display={err}, debug={err:?}"
635        );
636        let source = find_source_error::<reqwest::Error>(&err);
637        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
638
639        Ok(())
640    }
641
642    #[tokio::test]
643    #[serial]
644    async fn builder_no_mds() -> TestResult {
645        let Err(e) = Builder::default().build_token_provider().token().await else {
646            // The environment has an MDS, skip the test.
647            return Ok(());
648        };
649
650        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
651        assert!(
652            !format!("{:?}", original_err.source()).contains("application-default"),
653            "{e:?}"
654        );
655
656        Ok(())
657    }
658
659    #[tokio::test]
660    #[serial]
661    async fn test_gce_metadata_host_env_var() -> TestResult {
662        let server = Server::run();
663        let scopes = ["scope1", "scope2"];
664        let response = MDSTokenResponse {
665            access_token: "test-access-token".to_string(),
666            expires_in: Some(3600),
667            token_type: "test-token-type".to_string(),
668        };
669        server.expect(
670            Expectation::matching(all_of![
671                request::path(format!("{MDS_DEFAULT_URI}/token")),
672                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
673            ])
674            .respond_with(json_encoded(response)),
675        );
676
677        let addr = server.addr().to_string();
678        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
679        let mdsc = Builder::default()
680            .with_scopes(["scope1", "scope2"])
681            .build()
682            .unwrap();
683        let headers = mdsc.headers(Extensions::new()).await.unwrap();
684        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
685
686        assert_eq!(
687            get_token_from_headers(headers).unwrap(),
688            "test-access-token"
689        );
690        Ok(())
691    }
692
693    #[tokio::test]
694    #[parallel]
695    async fn headers_success_with_quota_project() -> TestResult {
696        let server = Server::run();
697        let scopes = ["scope1", "scope2"];
698        let response = MDSTokenResponse {
699            access_token: "test-access-token".to_string(),
700            expires_in: Some(3600),
701            token_type: "test-token-type".to_string(),
702        };
703        server.expect(
704            Expectation::matching(all_of![
705                request::path(format!("{MDS_DEFAULT_URI}/token")),
706                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
707            ])
708            .respond_with(json_encoded(response)),
709        );
710
711        let mdsc = Builder::default()
712            .with_scopes(["scope1", "scope2"])
713            .with_endpoint(format!("http://{}", server.addr()))
714            .with_quota_project_id("test-project")
715            .build()?;
716
717        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
718        let token = headers.get(AUTHORIZATION).unwrap();
719        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
720
721        assert_eq!(headers.len(), 2, "{headers:?}");
722        assert_eq!(
723            token,
724            HeaderValue::from_static("test-token-type test-access-token")
725        );
726        assert!(token.is_sensitive());
727        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
728        assert!(!quota_project.is_sensitive());
729
730        Ok(())
731    }
732
733    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
734    #[parallel]
735    async fn token_caching() -> TestResult {
736        let mut server = Server::run();
737        let scopes = vec!["scope1".to_string()];
738        let response = MDSTokenResponse {
739            access_token: "test-access-token".to_string(),
740            expires_in: Some(3600),
741            token_type: "test-token-type".to_string(),
742        };
743        server.expect(
744            Expectation::matching(all_of![
745                request::path(format!("{MDS_DEFAULT_URI}/token")),
746                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
747            ])
748            .times(1)
749            .respond_with(json_encoded(response)),
750        );
751
752        let mdsc = Builder::default()
753            .with_scopes(scopes)
754            .with_endpoint(format!("http://{}", server.addr()))
755            .build()?;
756        let headers = mdsc.headers(Extensions::new()).await?;
757        assert_eq!(
758            get_token_from_headers(headers).unwrap(),
759            "test-access-token"
760        );
761        let headers = mdsc.headers(Extensions::new()).await?;
762        assert_eq!(
763            get_token_from_headers(headers).unwrap(),
764            "test-access-token"
765        );
766
767        // validate that the inner token provider is called only once
768        server.verify_and_clear();
769
770        Ok(())
771    }
772
773    #[tokio::test(start_paused = true)]
774    #[parallel]
775    async fn token_provider_full() -> TestResult {
776        let server = Server::run();
777        let scopes = vec!["scope1".to_string()];
778        let response = MDSTokenResponse {
779            access_token: "test-access-token".to_string(),
780            expires_in: Some(3600),
781            token_type: "test-token-type".to_string(),
782        };
783        server.expect(
784            Expectation::matching(all_of![
785                request::path(format!("{MDS_DEFAULT_URI}/token")),
786                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
787            ])
788            .respond_with(json_encoded(response)),
789        );
790
791        let token = Builder::default()
792            .with_endpoint(format!("http://{}", server.addr()))
793            .with_scopes(scopes)
794            .build_token_provider()
795            .token()
796            .await?;
797
798        let now = tokio::time::Instant::now();
799        assert_eq!(token.token, "test-access-token");
800        assert_eq!(token.token_type, "test-token-type");
801        assert!(
802            token
803                .expires_at
804                .is_some_and(|d| d >= now + Duration::from_secs(3600))
805        );
806
807        Ok(())
808    }
809
810    #[tokio::test(start_paused = true)]
811    #[parallel]
812    async fn token_provider_full_no_scopes() -> TestResult {
813        let server = Server::run();
814        let response = MDSTokenResponse {
815            access_token: "test-access-token".to_string(),
816            expires_in: Some(3600),
817            token_type: "test-token-type".to_string(),
818        };
819        server.expect(
820            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
821                .respond_with(json_encoded(response)),
822        );
823
824        let token = Builder::default()
825            .with_endpoint(format!("http://{}", server.addr()))
826            .build_token_provider()
827            .token()
828            .await?;
829
830        let now = Instant::now();
831        assert_eq!(token.token, "test-access-token");
832        assert_eq!(token.token_type, "test-token-type");
833        assert!(
834            token
835                .expires_at
836                .is_some_and(|d| d == now + Duration::from_secs(3600))
837        );
838
839        Ok(())
840    }
841
842    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
843    #[parallel]
844    async fn credential_provider_full() -> TestResult {
845        let server = Server::run();
846        let scopes = vec!["scope1".to_string()];
847        let response = MDSTokenResponse {
848            access_token: "test-access-token".to_string(),
849            expires_in: None,
850            token_type: "test-token-type".to_string(),
851        };
852        server.expect(
853            Expectation::matching(all_of![
854                request::path(format!("{MDS_DEFAULT_URI}/token")),
855                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
856            ])
857            .respond_with(json_encoded(response)),
858        );
859
860        let mdsc = Builder::default()
861            .with_endpoint(format!("http://{}", server.addr()))
862            .with_scopes(scopes)
863            .build()?;
864        let headers = mdsc.headers(Extensions::new()).await?;
865        assert_eq!(
866            get_token_from_headers(headers.clone()).unwrap(),
867            "test-access-token"
868        );
869        assert_eq!(
870            get_token_type_from_headers(headers).unwrap(),
871            "test-token-type"
872        );
873
874        Ok(())
875    }
876
877    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
878    #[parallel]
879    async fn credentials_headers_retryable_error() -> TestResult {
880        let server = Server::run();
881        let scopes = vec!["scope1".to_string()];
882        server.expect(
883            Expectation::matching(all_of![
884                request::path(format!("{MDS_DEFAULT_URI}/token")),
885                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
886            ])
887            .respond_with(status_code(503)),
888        );
889
890        let mdsc = Builder::default()
891            .with_endpoint(format!("http://{}", server.addr()))
892            .with_scopes(scopes)
893            .build()?;
894        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
895        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
896        assert!(original_err.is_transient());
897        let source = find_source_error::<reqwest::Error>(&err);
898        assert!(
899            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
900            "{err:?}"
901        );
902
903        Ok(())
904    }
905
906    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
907    #[parallel]
908    async fn credentials_headers_nonretryable_error() -> TestResult {
909        let server = Server::run();
910        let scopes = vec!["scope1".to_string()];
911        server.expect(
912            Expectation::matching(all_of![
913                request::path(format!("{MDS_DEFAULT_URI}/token")),
914                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
915            ])
916            .respond_with(status_code(401)),
917        );
918
919        let mdsc = Builder::default()
920            .with_endpoint(format!("http://{}", server.addr()))
921            .with_scopes(scopes)
922            .build()?;
923
924        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
925        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
926        assert!(!original_err.is_transient());
927        let source = find_source_error::<reqwest::Error>(&err);
928        assert!(
929            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
930            "{err:?}"
931        );
932
933        Ok(())
934    }
935
936    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
937    #[parallel]
938    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
939        let server = Server::run();
940        let scopes = vec!["scope1".to_string()];
941        server.expect(
942            Expectation::matching(all_of![
943                request::path(format!("{MDS_DEFAULT_URI}/token")),
944                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
945            ])
946            .respond_with(json_encoded("bad json")),
947        );
948
949        let mdsc = Builder::default()
950            .with_endpoint(format!("http://{}", server.addr()))
951            .with_scopes(scopes)
952            .build()?;
953
954        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
955        assert!(!e.is_transient());
956
957        Ok(())
958    }
959
960    #[tokio::test]
961    async fn get_default_universe_domain_success() -> TestResult {
962        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
963        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
964        Ok(())
965    }
966}