google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
28//! You can use this access token to securely authenticate with Google Cloud,
29//! without having to download secrets or other credentials. The types in this
30//! module allow you to retrieve these access tokens, and can be used with
31//! the Google Cloud client libraries for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok::<(), anyhow::Error>(())
46//! # });
47//! ```
48//!
49//! ## Example: Creating credentials with custom retry behavior
50//!
51//! ```
52//! # use google_cloud_auth::credentials::mds::Builder;
53//! # use google_cloud_auth::credentials::Credentials;
54//! # use http::Extensions;
55//! # use std::time::Duration;
56//! # tokio_test::block_on(async {
57//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
58//! use gax::exponential_backoff::ExponentialBackoff;
59//! let backoff = ExponentialBackoff::default();
60//! let credentials: Credentials = Builder::default()
61//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
62//!     .with_backoff_policy(backoff)
63//!     .build()?;
64//! let headers = credentials.headers(Extensions::new()).await?;
65//! println!("Headers: {headers:?}");
66//! # Ok::<(), anyhow::Error>(())
67//! # });
68//! ```
69//!
70//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
71//! [Cloud Run]: https://cloud.google.com/run
72//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
73//! [gce-link]: https://cloud.google.com/products/compute
74//! [gke-link]: https://cloud.google.com/kubernetes-engine
75//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
76
77use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
// Request header sent with every metadata service call, and its value.
const METADATA_FLAVOR: &str = "metadata-flavor";
const METADATA_FLAVOR_VALUE: &str = "Google";

// Default base URL, and the environment variable that can replace its host.
const METADATA_ROOT: &str = "http://metadata.google.internal";
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";

// Path (relative to the base URL) of the default service account resources.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";

// Message used when the credentials were created by ADC and the endpoint was
// not overridden; it points users at local credential setup.
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
110
111#[derive(Debug)]
112struct MDSCredentials<T>
113where
114    T: CachedTokenProvider,
115{
116    quota_project_id: Option<String>,
117    token_provider: T,
118}
119
120/// Creates [Credentials] instances backed by the [Metadata Service].
121///
122/// While the Google Cloud client libraries for Rust default to credentials
123/// backed by the metadata service, some applications may need to:
124/// * Customize the metadata service credentials in some way
125/// * Bypass the [Application Default Credentials] lookup and only
126///   use the metadata server credentials
127/// * Use the credentials directly outside the client libraries
128///
129/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
130/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
131#[derive(Debug, Default)]
132pub struct Builder {
133    endpoint: Option<String>,
134    quota_project_id: Option<String>,
135    scopes: Option<Vec<String>>,
136    created_by_adc: bool,
137    retry_builder: RetryTokenProviderBuilder,
138}
139
140impl Builder {
141    /// Sets the endpoint for this credentials.
142    ///
143    /// A trailing slash is significant, so specify the base URL without a trailing  
144    /// slash. If not set, the credentials use `http://metadata.google.internal`.
145    ///
146    /// # Example
147    /// ```
148    /// # use google_cloud_auth::credentials::mds::Builder;
149    /// # tokio_test::block_on(async {
150    /// let credentials = Builder::default()
151    ///     .with_endpoint("https://metadata.google.foobar")
152    ///     .build();
153    /// # });
154    /// ```
155    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
156        self.endpoint = Some(endpoint.into());
157        self
158    }
159
160    /// Set the [quota project] for this credentials.
161    ///
162    /// In some services, you can use a service account in
163    /// one project for authentication and authorization, and charge
164    /// the usage to a different project. This may require that the
165    /// service account has `serviceusage.services.use` permissions on the quota project.
166    ///
167    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
168    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
169        self.quota_project_id = Some(quota_project_id.into());
170        self
171    }
172
173    /// Sets the [scopes] for this credentials.
174    ///
175    /// Metadata server issues tokens based on the requested scopes.
176    /// If no scopes are specified, the credentials defaults to all
177    /// scopes configured for the [default service account] on the instance.
178    ///
179    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
180    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
181    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
182    where
183        I: IntoIterator<Item = S>,
184        S: Into<String>,
185    {
186        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
187        self
188    }
189
190    /// Configure the retry policy for fetching tokens.
191    ///
192    /// The retry policy controls how to handle retries, and sets limits on
193    /// the number of attempts or the total time spent retrying.
194    ///
195    /// ```
196    /// # use google_cloud_auth::credentials::mds::Builder;
197    /// # tokio_test::block_on(async {
198    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
199    /// let credentials = Builder::default()
200    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
201    ///     .build();
202    /// # });
203    /// ```
204    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
205        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
206        self
207    }
208
209    /// Configure the retry backoff policy.
210    ///
211    /// The backoff policy controls how long to wait in between retry attempts.
212    ///
213    /// ```
214    /// # use google_cloud_auth::credentials::mds::Builder;
215    /// # use std::time::Duration;
216    /// # tokio_test::block_on(async {
217    /// use gax::exponential_backoff::ExponentialBackoff;
218    /// let policy = ExponentialBackoff::default();
219    /// let credentials = Builder::default()
220    ///     .with_backoff_policy(policy)
221    ///     .build();
222    /// # });
223    /// ```
224    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
225        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
226        self
227    }
228
229    /// Configure the retry throttler.
230    ///
231    /// Advanced applications may want to configure a retry throttler to
232    /// [Address Cascading Failures] and when [Handling Overload] conditions.
233    /// The authentication library throttles its retry loop, using a policy to
234    /// control the throttling algorithm. Use this method to fine tune or
235    /// customize the default retry throttler.
236    ///
237    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
238    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
239    ///
240    /// ```
241    /// # use google_cloud_auth::credentials::mds::Builder;
242    /// # tokio_test::block_on(async {
243    /// use gax::retry_throttler::AdaptiveThrottler;
244    /// let credentials = Builder::default()
245    ///     .with_retry_throttler(AdaptiveThrottler::default())
246    ///     .build();
247    /// # });
248    /// ```
249    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
250        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
251        self
252    }
253
254    // This method is used to build mds credentials from ADC
255    pub(crate) fn from_adc() -> Self {
256        Self {
257            created_by_adc: true,
258            ..Default::default()
259        }
260    }
261
262    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
263        let final_endpoint: String;
264        let endpoint_overridden: bool;
265
266        // Determine the endpoint and whether it was overridden
267        if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
268            // Check GCE_METADATA_HOST environment variable first
269            final_endpoint = format!("http://{host_from_env}");
270            endpoint_overridden = true;
271        } else if let Some(builder_endpoint) = self.endpoint {
272            // Else, check if an endpoint was provided to the mds::Builder
273            final_endpoint = builder_endpoint;
274            endpoint_overridden = true;
275        } else {
276            // Else, use the default metadata root
277            final_endpoint = METADATA_ROOT.to_string();
278            endpoint_overridden = false;
279        };
280
281        let tp = MDSAccessTokenProvider::builder()
282            .endpoint(final_endpoint)
283            .maybe_scopes(self.scopes)
284            .endpoint_overridden(endpoint_overridden)
285            .created_by_adc(self.created_by_adc)
286            .build();
287        self.retry_builder.build(tp)
288    }
289
290    /// Returns a [Credentials] instance with the configured settings.
291    pub fn build(self) -> BuildResult<Credentials> {
292        let mdsc = MDSCredentials {
293            quota_project_id: self.quota_project_id.clone(),
294            token_provider: TokenCache::new(self.build_token_provider()),
295        };
296        Ok(Credentials {
297            inner: Arc::new(mdsc),
298        })
299    }
300}
301
302#[async_trait::async_trait]
303impl<T> CredentialsProvider for MDSCredentials<T>
304where
305    T: CachedTokenProvider,
306{
307    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
308        let cached_token = self.token_provider.token(extensions).await?;
309        build_cacheable_headers(&cached_token, &self.quota_project_id)
310    }
311}
312
313#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
314struct MDSTokenResponse {
315    access_token: String,
316    #[serde(skip_serializing_if = "Option::is_none")]
317    expires_in: Option<u64>,
318    token_type: String,
319}
320
321#[derive(Debug, Clone, Default, Builder)]
322struct MDSAccessTokenProvider {
323    #[builder(into)]
324    scopes: Option<Vec<String>>,
325    #[builder(into)]
326    endpoint: String,
327    endpoint_overridden: bool,
328    created_by_adc: bool,
329}
330
331impl MDSAccessTokenProvider {
332    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
333    // environment variable is not set, we default to MDS credentials without checking if the code is really
334    // running in an environment with MDS. To help users who got to this state because of lack of credentials
335    // setup on their machines, we provide a detailed error message to them talking about local setup and other
336    // auth mechanisms available to them.
337    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
338    // error message because they deliberately wanted to use an MDS.
339    fn error_message(&self) -> &str {
340        if self.use_adc_message() {
341            MDS_NOT_FOUND_ERROR
342        } else {
343            "failed to fetch token"
344        }
345    }
346
347    fn use_adc_message(&self) -> bool {
348        self.created_by_adc && !self.endpoint_overridden
349    }
350}
351
352#[async_trait]
353impl TokenProvider for MDSAccessTokenProvider {
354    async fn token(&self) -> Result<Token> {
355        let client = Client::new();
356        let request = client
357            .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
358            .header(
359                METADATA_FLAVOR,
360                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
361            );
362        // Use the `scopes` option if set, otherwise let the MDS use the default
363        // scopes.
364        let scopes = self.scopes.as_ref().map(|v| v.join(","));
365        let request = scopes
366            .into_iter()
367            .fold(request, |r, s| r.query(&[("scopes", s)]));
368
369        // If the connection to MDS was not successful, it is useful to retry when really
370        // running on MDS environments and not useful if there is no MDS. We will mark the error
371        // as retryable and let the retry policy determine whether to retry or not. Whenever we
372        // define a default retry policy, we can skip retrying this case.
373        let response = request
374            .send()
375            .await
376            .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
377        // Process the response
378        if !response.status().is_success() {
379            let err = crate::errors::from_http_response(response, self.error_message()).await;
380            return Err(err);
381        }
382        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
383            // Decoding errors are not transient. Typically they indicate a badly
384            // configured MDS endpoint, or DNS redirecting the request to a random
385            // server, e.g., ISPs that redirect unknown services to HTTP.
386            CredentialsError::from_source(!e.is_decode(), e)
387        })?;
388        let token = Token {
389            token: response.access_token,
390            token_type: response.token_type,
391            expires_at: response
392                .expires_in
393                .map(|d| Instant::now() + Duration::from_secs(d)),
394            metadata: None,
395        };
396        Ok(token)
397    }
398}
399
400#[allow(dead_code)]
401pub(crate) mod idtoken {
402    //! Types for fetching ID tokens from the metadata service.
403    use std::sync::Arc;
404
405    use super::{
406        GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_FLAVOR, METADATA_FLAVOR_VALUE,
407        METADATA_ROOT,
408    };
409    use crate::Result;
410    use crate::errors::CredentialsError;
411    use crate::token::{Token, TokenProvider};
412    use crate::{
413        BuildResult,
414        credentials::idtoken::{IDTokenCredentials, dynamic::IDTokenCredentialsProvider},
415    };
416    use async_trait::async_trait;
417    use http::HeaderValue;
418    use reqwest::Client;
419
420    #[derive(Debug)]
421    pub(crate) struct MDSCredentials<T>
422    where
423        T: TokenProvider,
424    {
425        token_provider: T,
426    }
427
428    #[async_trait]
429    impl<T> IDTokenCredentialsProvider for MDSCredentials<T>
430    where
431        T: TokenProvider,
432    {
433        async fn id_token(&self) -> Result<Token> {
434            self.token_provider.token().await
435        }
436    }
437
438    /// Creates [`IDTokenCredentials`] instances that fetch ID tokens from the
439    /// metadata service.
440    #[derive(Debug, Default)]
441    pub struct Builder {
442        endpoint: Option<String>,
443        format: Option<String>,
444        licenses: Option<String>,
445        target_audience: String,
446    }
447
448    impl Builder {
449        /// Creates a new `Builder`.
450        ///
451        /// The `target_audience` is a required parameter that specifies the
452        /// intended audience of the ID token. This is typically the URL of the
453        /// service that will be receiving the token.
454        pub fn new<S: Into<String>>(target_audience: S) -> Self {
455            Builder {
456                format: None,
457                endpoint: None,
458                licenses: None,
459                target_audience: target_audience.into(),
460            }
461        }
462
463        /// Sets the endpoint for this credentials.
464        ///
465        /// A trailing slash is significant, so specify the base URL without a trailing  
466        /// slash. If not set, the credentials use `http://metadata.google.internal`.
467        pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
468            self.endpoint = Some(endpoint.into());
469            self
470        }
471
472        /// Sets the [format] of the token.
473        ///
474        /// Specifies whether or not the project and instance details are included in the payload.
475        /// Specify `full` to include this information in the payload or `standard` to omit the information
476        /// from the payload. The default value is `standard``.
477        ///
478        /// [format]: https://cloud.google.com/compute/docs/instances/verifying-instance-identity#token_format
479        pub fn with_format<S: Into<String>>(mut self, format: S) -> Self {
480            self.format = Some(format.into());
481            self
482        }
483
484        /// Whether to include the [license codes] of the instance in the token.
485        ///
486        /// Specify `true` to include this information or `false` to omit this information from the payload.
487        /// The default value is `false`. Has no effect unless format is `full`.
488        ///
489        /// [license codes]: https://cloud.google.com/compute/docs/reference/rest/v1/images/get#body.Image.FIELDS.license_code
490        pub fn with_licenses(mut self, licenses: bool) -> Self {
491            self.licenses = if licenses {
492                Some("TRUE".to_string())
493            } else {
494                Some("FALSE".to_string())
495            };
496            self
497        }
498
499        fn build_token_provider(self) -> MDSTokenProvider {
500            let final_endpoint: String;
501
502            // Determine the endpoint and whether it was overridden
503            if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
504                // Check GCE_METADATA_HOST environment variable first
505                final_endpoint = format!("http://{host_from_env}");
506            } else if let Some(builder_endpoint) = self.endpoint {
507                // Else, check if an endpoint was provided to the mds::Builder
508                final_endpoint = builder_endpoint;
509            } else {
510                // Else, use the default metadata root
511                final_endpoint = METADATA_ROOT.to_string();
512            };
513
514            MDSTokenProvider {
515                format: self.format,
516                licenses: self.licenses,
517                endpoint: final_endpoint,
518                target_audience: self.target_audience,
519            }
520        }
521
522        /// Returns an [`IDTokenCredentials`] instance with the configured
523        /// settings.
524        pub fn build(self) -> BuildResult<IDTokenCredentials> {
525            let creds = MDSCredentials {
526                token_provider: self.build_token_provider(),
527            };
528            Ok(IDTokenCredentials {
529                inner: Arc::new(creds),
530            })
531        }
532    }
533
534    #[derive(Debug, Clone, Default)]
535    struct MDSTokenProvider {
536        endpoint: String,
537        format: Option<String>,
538        licenses: Option<String>,
539        target_audience: String,
540    }
541
542    #[async_trait]
543    impl TokenProvider for MDSTokenProvider {
544        async fn token(&self) -> Result<Token> {
545            let client = Client::new();
546            let audience = self.target_audience.clone();
547            let request = client
548                .get(format!("{}{}/identity", self.endpoint, MDS_DEFAULT_URI))
549                .header(
550                    METADATA_FLAVOR,
551                    HeaderValue::from_static(METADATA_FLAVOR_VALUE),
552                )
553                .query(&[("audience", audience)]);
554            let request = self.format.iter().fold(request, |builder, format| {
555                builder.query(&[("format", format)])
556            });
557            let request = self.licenses.iter().fold(request, |builder, licenses| {
558                builder.query(&[("licenses", licenses)])
559            });
560
561            let response = request
562                .send()
563                .await
564                .map_err(|e| crate::errors::from_http_error(e, "failed to fetch token"))?;
565
566            if !response.status().is_success() {
567                let err =
568                    crate::errors::from_http_response(response, "failed to fetch token").await;
569                return Err(err);
570            }
571
572            let token = response
573                .text()
574                .await
575                .map_err(|e| CredentialsError::from_source(!e.is_decode(), e))?;
576
577            Ok(Token {
578                token,
579                token_type: "Bearer".to_string(),
580                // ID tokens from MDS do not have an expiry.
581                expires_at: None,
582                metadata: None,
583            })
584        }
585    }
586}
587
588#[cfg(test)]
589mod tests {
590    use super::idtoken;
591    use super::*;
592    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
593    use crate::credentials::QUOTA_PROJECT_KEY;
594    use crate::credentials::tests::{
595        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
596        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
597        get_token_type_from_headers,
598    };
599    use crate::errors;
600    use crate::errors::CredentialsError;
601    use crate::token::tests::MockTokenProvider;
602    use http::HeaderValue;
603    use http::header::AUTHORIZATION;
604    use httptest::cycle;
605    use httptest::matchers::{all_of, contains, request, url_decoded};
606    use httptest::responders::{json_encoded, status_code};
607    use httptest::{Expectation, Server};
608    use reqwest::StatusCode;
609    use scoped_env::ScopedEnv;
610    use serial_test::{parallel, serial};
611    use std::error::Error;
612    use test_case::test_case;
613    use url::Url;
614
615    type TestResult = anyhow::Result<()>;
616
617    #[tokio::test]
618    #[parallel]
619    async fn test_mds_retries_on_transient_failures() -> TestResult {
620        let mut server = Server::run();
621        server.expect(
622            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
623                .times(3)
624                .respond_with(status_code(503)),
625        );
626
627        let provider = Builder::default()
628            .with_endpoint(format!("http://{}", server.addr()))
629            .with_retry_policy(get_mock_auth_retry_policy(3))
630            .with_backoff_policy(get_mock_backoff_policy())
631            .with_retry_throttler(get_mock_retry_throttler())
632            .build_token_provider();
633
634        let err = provider.token().await.unwrap_err();
635        assert!(!err.is_transient());
636        server.verify_and_clear();
637        Ok(())
638    }
639
640    #[tokio::test]
641    #[parallel]
642    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
643        let mut server = Server::run();
644        server.expect(
645            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
646                .times(1)
647                .respond_with(status_code(401)),
648        );
649
650        let provider = Builder::default()
651            .with_endpoint(format!("http://{}", server.addr()))
652            .with_retry_policy(get_mock_auth_retry_policy(1))
653            .with_backoff_policy(get_mock_backoff_policy())
654            .with_retry_throttler(get_mock_retry_throttler())
655            .build_token_provider();
656
657        let err = provider.token().await.unwrap_err();
658        assert!(!err.is_transient());
659        server.verify_and_clear();
660        Ok(())
661    }
662
663    #[tokio::test]
664    #[parallel]
665    async fn test_mds_retries_for_success() -> TestResult {
666        let mut server = Server::run();
667        let response = MDSTokenResponse {
668            access_token: "test-access-token".to_string(),
669            expires_in: Some(3600),
670            token_type: "test-token-type".to_string(),
671        };
672
673        server.expect(
674            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
675                .times(3)
676                .respond_with(cycle![
677                    status_code(503).body("try-again"),
678                    status_code(503).body("try-again"),
679                    status_code(200)
680                        .append_header("Content-Type", "application/json")
681                        .body(serde_json::to_string(&response).unwrap()),
682                ]),
683        );
684
685        let provider = Builder::default()
686            .with_endpoint(format!("http://{}", server.addr()))
687            .with_retry_policy(get_mock_auth_retry_policy(3))
688            .with_backoff_policy(get_mock_backoff_policy())
689            .with_retry_throttler(get_mock_retry_throttler())
690            .build_token_provider();
691
692        let token = provider.token().await?;
693        assert_eq!(token.token, "test-access-token");
694
695        server.verify_and_clear();
696        Ok(())
697    }
698
699    #[test]
700    fn validate_default_endpoint_urls() {
701        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
702        assert!(default_endpoint_address.is_ok());
703
704        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
705        assert!(token_endpoint_address.is_ok());
706    }
707
708    #[tokio::test]
709    async fn headers_success() -> TestResult {
710        let token = Token {
711            token: "test-token".to_string(),
712            token_type: "Bearer".to_string(),
713            expires_at: None,
714            metadata: None,
715        };
716
717        let mut mock = MockTokenProvider::new();
718        mock.expect_token().times(1).return_once(|| Ok(token));
719
720        let mdsc = MDSCredentials {
721            quota_project_id: None,
722            token_provider: TokenCache::new(mock),
723        };
724
725        let mut extensions = Extensions::new();
726        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
727        let (headers, entity_tag) = match cached_headers {
728            CacheableResource::New { entity_tag, data } => (data, entity_tag),
729            CacheableResource::NotModified => unreachable!("expecting new headers"),
730        };
731        let token = headers.get(AUTHORIZATION).unwrap();
732        assert_eq!(headers.len(), 1, "{headers:?}");
733        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
734        assert!(token.is_sensitive());
735
736        extensions.insert(entity_tag);
737
738        let cached_headers = mdsc.headers(extensions).await?;
739
740        match cached_headers {
741            CacheableResource::New { .. } => unreachable!("expecting new headers"),
742            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
743        };
744        Ok(())
745    }
746
747    #[tokio::test]
748    async fn headers_failure() {
749        let mut mock = MockTokenProvider::new();
750        mock.expect_token()
751            .times(1)
752            .return_once(|| Err(errors::non_retryable_from_str("fail")));
753
754        let mdsc = MDSCredentials {
755            quota_project_id: None,
756            token_provider: TokenCache::new(mock),
757        };
758        assert!(mdsc.headers(Extensions::new()).await.is_err());
759    }
760
761    #[test]
762    fn error_message_with_adc() {
763        let provider = MDSAccessTokenProvider::builder()
764            .endpoint("http://127.0.0.1")
765            .created_by_adc(true)
766            .endpoint_overridden(false)
767            .build();
768
769        let want = MDS_NOT_FOUND_ERROR;
770        let got = provider.error_message();
771        assert!(got.contains(want), "{got}, {provider:?}");
772    }
773
774    #[test_case(false, false)]
775    #[test_case(false, true)]
776    #[test_case(true, true)]
777    fn error_message_without_adc(adc: bool, overridden: bool) {
778        let provider = MDSAccessTokenProvider::builder()
779            .endpoint("http://127.0.0.1")
780            .created_by_adc(adc)
781            .endpoint_overridden(overridden)
782            .build();
783
784        let not_want = MDS_NOT_FOUND_ERROR;
785        let got = provider.error_message();
786        assert!(!got.contains(not_want), "{got}, {provider:?}");
787    }
788
789    #[tokio::test]
790    #[serial]
791    async fn adc_no_mds() -> TestResult {
792        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
793            // The environment has an MDS, skip the test.
794            return Ok(());
795        };
796
797        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
798        assert!(
799            original_err.to_string().contains("application-default"),
800            "display={err}, debug={err:?}"
801        );
802
803        Ok(())
804    }
805
806    #[tokio::test]
807    #[serial]
808    async fn adc_overridden_mds() -> TestResult {
809        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
810
811        let err = Builder::from_adc()
812            .build_token_provider()
813            .token()
814            .await
815            .unwrap_err();
816
817        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
818
819        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
820        assert!(original_err.is_transient());
821        assert!(
822            !original_err.to_string().contains("application-default"),
823            "display={err}, debug={err:?}"
824        );
825        let source = find_source_error::<reqwest::Error>(&err);
826        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
827
828        Ok(())
829    }
830
831    #[tokio::test]
832    #[serial]
833    async fn builder_no_mds() -> TestResult {
834        let Err(e) = Builder::default().build_token_provider().token().await else {
835            // The environment has an MDS, skip the test.
836            return Ok(());
837        };
838
839        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
840        assert!(
841            !format!("{:?}", original_err.source()).contains("application-default"),
842            "{e:?}"
843        );
844
845        Ok(())
846    }
847
848    #[tokio::test]
849    #[serial]
850    async fn test_gce_metadata_host_env_var() -> TestResult {
851        let server = Server::run();
852        let scopes = ["scope1", "scope2"];
853        let response = MDSTokenResponse {
854            access_token: "test-access-token".to_string(),
855            expires_in: Some(3600),
856            token_type: "test-token-type".to_string(),
857        };
858        server.expect(
859            Expectation::matching(all_of![
860                request::path(format!("{MDS_DEFAULT_URI}/token")),
861                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
862            ])
863            .respond_with(json_encoded(response)),
864        );
865
866        let addr = server.addr().to_string();
867        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
868        let mdsc = Builder::default()
869            .with_scopes(["scope1", "scope2"])
870            .build()
871            .unwrap();
872        let headers = mdsc.headers(Extensions::new()).await.unwrap();
873        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
874
875        assert_eq!(
876            get_token_from_headers(headers).unwrap(),
877            "test-access-token"
878        );
879        Ok(())
880    }
881
882    #[tokio::test]
883    #[parallel]
884    async fn headers_success_with_quota_project() -> TestResult {
885        let server = Server::run();
886        let scopes = ["scope1", "scope2"];
887        let response = MDSTokenResponse {
888            access_token: "test-access-token".to_string(),
889            expires_in: Some(3600),
890            token_type: "test-token-type".to_string(),
891        };
892        server.expect(
893            Expectation::matching(all_of![
894                request::path(format!("{MDS_DEFAULT_URI}/token")),
895                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
896            ])
897            .respond_with(json_encoded(response)),
898        );
899
900        let mdsc = Builder::default()
901            .with_scopes(["scope1", "scope2"])
902            .with_endpoint(format!("http://{}", server.addr()))
903            .with_quota_project_id("test-project")
904            .build()?;
905
906        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
907        let token = headers.get(AUTHORIZATION).unwrap();
908        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
909
910        assert_eq!(headers.len(), 2, "{headers:?}");
911        assert_eq!(
912            token,
913            HeaderValue::from_static("test-token-type test-access-token")
914        );
915        assert!(token.is_sensitive());
916        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
917        assert!(!quota_project.is_sensitive());
918
919        Ok(())
920    }
921
922    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
923    #[parallel]
924    async fn token_caching() -> TestResult {
925        let mut server = Server::run();
926        let scopes = vec!["scope1".to_string()];
927        let response = MDSTokenResponse {
928            access_token: "test-access-token".to_string(),
929            expires_in: Some(3600),
930            token_type: "test-token-type".to_string(),
931        };
932        server.expect(
933            Expectation::matching(all_of![
934                request::path(format!("{MDS_DEFAULT_URI}/token")),
935                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
936            ])
937            .times(1)
938            .respond_with(json_encoded(response)),
939        );
940
941        let mdsc = Builder::default()
942            .with_scopes(scopes)
943            .with_endpoint(format!("http://{}", server.addr()))
944            .build()?;
945        let headers = mdsc.headers(Extensions::new()).await?;
946        assert_eq!(
947            get_token_from_headers(headers).unwrap(),
948            "test-access-token"
949        );
950        let headers = mdsc.headers(Extensions::new()).await?;
951        assert_eq!(
952            get_token_from_headers(headers).unwrap(),
953            "test-access-token"
954        );
955
956        // validate that the inner token provider is called only once
957        server.verify_and_clear();
958
959        Ok(())
960    }
961
962    #[tokio::test(start_paused = true)]
963    #[parallel]
964    async fn token_provider_full() -> TestResult {
965        let server = Server::run();
966        let scopes = vec!["scope1".to_string()];
967        let response = MDSTokenResponse {
968            access_token: "test-access-token".to_string(),
969            expires_in: Some(3600),
970            token_type: "test-token-type".to_string(),
971        };
972        server.expect(
973            Expectation::matching(all_of![
974                request::path(format!("{MDS_DEFAULT_URI}/token")),
975                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
976            ])
977            .respond_with(json_encoded(response)),
978        );
979
980        let token = Builder::default()
981            .with_endpoint(format!("http://{}", server.addr()))
982            .with_scopes(scopes)
983            .build_token_provider()
984            .token()
985            .await?;
986
987        let now = tokio::time::Instant::now();
988        assert_eq!(token.token, "test-access-token");
989        assert_eq!(token.token_type, "test-token-type");
990        assert!(
991            token
992                .expires_at
993                .is_some_and(|d| d >= now + Duration::from_secs(3600))
994        );
995
996        Ok(())
997    }
998
999    #[tokio::test(start_paused = true)]
1000    #[parallel]
1001    async fn token_provider_full_no_scopes() -> TestResult {
1002        let server = Server::run();
1003        let response = MDSTokenResponse {
1004            access_token: "test-access-token".to_string(),
1005            expires_in: Some(3600),
1006            token_type: "test-token-type".to_string(),
1007        };
1008        server.expect(
1009            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
1010                .respond_with(json_encoded(response)),
1011        );
1012
1013        let token = Builder::default()
1014            .with_endpoint(format!("http://{}", server.addr()))
1015            .build_token_provider()
1016            .token()
1017            .await?;
1018
1019        let now = Instant::now();
1020        assert_eq!(token.token, "test-access-token");
1021        assert_eq!(token.token_type, "test-token-type");
1022        assert!(
1023            token
1024                .expires_at
1025                .is_some_and(|d| d == now + Duration::from_secs(3600))
1026        );
1027
1028        Ok(())
1029    }
1030
1031    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1032    #[parallel]
1033    async fn credential_provider_full() -> TestResult {
1034        let server = Server::run();
1035        let scopes = vec!["scope1".to_string()];
1036        let response = MDSTokenResponse {
1037            access_token: "test-access-token".to_string(),
1038            expires_in: None,
1039            token_type: "test-token-type".to_string(),
1040        };
1041        server.expect(
1042            Expectation::matching(all_of![
1043                request::path(format!("{MDS_DEFAULT_URI}/token")),
1044                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1045            ])
1046            .respond_with(json_encoded(response)),
1047        );
1048
1049        let mdsc = Builder::default()
1050            .with_endpoint(format!("http://{}", server.addr()))
1051            .with_scopes(scopes)
1052            .build()?;
1053        let headers = mdsc.headers(Extensions::new()).await?;
1054        assert_eq!(
1055            get_token_from_headers(headers.clone()).unwrap(),
1056            "test-access-token"
1057        );
1058        assert_eq!(
1059            get_token_type_from_headers(headers).unwrap(),
1060            "test-token-type"
1061        );
1062
1063        Ok(())
1064    }
1065
1066    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1067    #[parallel]
1068    async fn credentials_headers_retryable_error() -> TestResult {
1069        let server = Server::run();
1070        let scopes = vec!["scope1".to_string()];
1071        server.expect(
1072            Expectation::matching(all_of![
1073                request::path(format!("{MDS_DEFAULT_URI}/token")),
1074                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1075            ])
1076            .respond_with(status_code(503)),
1077        );
1078
1079        let mdsc = Builder::default()
1080            .with_endpoint(format!("http://{}", server.addr()))
1081            .with_scopes(scopes)
1082            .build()?;
1083        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1084        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1085        assert!(original_err.is_transient());
1086        let source = find_source_error::<reqwest::Error>(&err);
1087        assert!(
1088            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1089            "{err:?}"
1090        );
1091
1092        Ok(())
1093    }
1094
1095    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1096    #[parallel]
1097    async fn credentials_headers_nonretryable_error() -> TestResult {
1098        let server = Server::run();
1099        let scopes = vec!["scope1".to_string()];
1100        server.expect(
1101            Expectation::matching(all_of![
1102                request::path(format!("{MDS_DEFAULT_URI}/token")),
1103                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1104            ])
1105            .respond_with(status_code(401)),
1106        );
1107
1108        let mdsc = Builder::default()
1109            .with_endpoint(format!("http://{}", server.addr()))
1110            .with_scopes(scopes)
1111            .build()?;
1112
1113        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1114        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1115        assert!(!original_err.is_transient());
1116        let source = find_source_error::<reqwest::Error>(&err);
1117        assert!(
1118            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1119            "{err:?}"
1120        );
1121
1122        Ok(())
1123    }
1124
1125    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1126    #[parallel]
1127    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
1128        let server = Server::run();
1129        let scopes = vec!["scope1".to_string()];
1130        server.expect(
1131            Expectation::matching(all_of![
1132                request::path(format!("{MDS_DEFAULT_URI}/token")),
1133                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1134            ])
1135            .respond_with(json_encoded("bad json")),
1136        );
1137
1138        let mdsc = Builder::default()
1139            .with_endpoint(format!("http://{}", server.addr()))
1140            .with_scopes(scopes)
1141            .build()?;
1142
1143        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1144        assert!(!e.is_transient());
1145
1146        Ok(())
1147    }
1148
1149    #[tokio::test]
1150    async fn get_default_universe_domain_success() -> TestResult {
1151        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
1152        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
1153        Ok(())
1154    }
1155
1156    #[tokio::test]
1157    #[parallel]
1158    async fn test_idtoken_builder_build() -> TestResult {
1159        let server = Server::run();
1160        let audience = "test-audience";
1161        let format = "format";
1162        let token_string = "test-id-token";
1163        server.expect(
1164            Expectation::matching(all_of![
1165                request::path(format!("{MDS_DEFAULT_URI}/identity")),
1166                request::query(url_decoded(contains(("audience", audience)))),
1167                request::query(url_decoded(contains(("format", format)))),
1168                request::query(url_decoded(contains(("licenses", "TRUE"))))
1169            ])
1170            .respond_with(status_code(200).body(token_string)),
1171        );
1172
1173        let creds = idtoken::Builder::new(audience)
1174            .with_endpoint(format!("http://{}", server.addr()))
1175            .with_format(format)
1176            .with_licenses(true)
1177            .build()?;
1178
1179        let token = creds.id_token().await?;
1180        assert_eq!(token.token, token_string);
1181        assert_eq!(token.token_type, "Bearer");
1182        assert!(token.expires_at.is_none());
1183        Ok(())
1184    }
1185
1186    #[tokio::test]
1187    #[serial]
1188    async fn test_idtoken_builder_build_with_env_var() -> TestResult {
1189        let server = Server::run();
1190        let audience = "test-audience";
1191        let token_string = "test-id-token";
1192        server.expect(
1193            Expectation::matching(all_of![
1194                request::path(format!("{MDS_DEFAULT_URI}/identity")),
1195                request::query(url_decoded(contains(("audience", audience))))
1196            ])
1197            .respond_with(status_code(200).body(token_string)),
1198        );
1199
1200        let addr = server.addr().to_string();
1201        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
1202
1203        let creds = idtoken::Builder::new(audience).build()?;
1204
1205        let token = creds.id_token().await?;
1206        assert_eq!(token.token, token_string);
1207
1208        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
1209        Ok(())
1210    }
1211
1212    #[tokio::test]
1213    #[parallel]
1214    async fn test_idtoken_provider_http_error() -> TestResult {
1215        let server = Server::run();
1216        let audience = "test-audience";
1217        server.expect(
1218            Expectation::matching(all_of![
1219                request::path(format!("{MDS_DEFAULT_URI}/identity")),
1220                request::query(url_decoded(contains(("audience", audience))))
1221            ])
1222            .respond_with(status_code(503)),
1223        );
1224
1225        let creds = idtoken::Builder::new(audience)
1226            .with_endpoint(format!("http://{}", server.addr()))
1227            .build()?;
1228
1229        let err = creds.id_token().await.unwrap_err();
1230        let source = find_source_error::<reqwest::Error>(&err);
1231        assert!(
1232            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1233            "{err:?}"
1234        );
1235        Ok(())
1236    }
1237}