google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
28//! You can use this access token to securely authenticate with Google Cloud,
29//! without having to download secrets or other credentials. The types in this
30//! module allow you to retrieve these access tokens, and can be used with
31//! the Google Cloud client libraries for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok::<(), anyhow::Error>(())
46//! # });
47//! ```
48//!
49//! ## Example: Creating credentials with custom retry behavior
50//!
51//! ```
52//! # use google_cloud_auth::credentials::mds::Builder;
53//! # use google_cloud_auth::credentials::Credentials;
54//! # use http::Extensions;
55//! # use std::time::Duration;
56//! # tokio_test::block_on(async {
57//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
58//! use gax::exponential_backoff::ExponentialBackoff;
59//! let backoff = ExponentialBackoff::default();
60//! let credentials: Credentials = Builder::default()
61//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
62//!     .with_backoff_policy(backoff)
63//!     .build()?;
64//! let headers = credentials.headers(Extensions::new()).await?;
65//! println!("Headers: {headers:?}");
66//! # Ok::<(), anyhow::Error>(())
67//! # });
68//! ```
69//!
70//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
71//! [Cloud Run]: https://cloud.google.com/run
72//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
73//! [gce-link]: https://cloud.google.com/products/compute
74//! [gke-link]: https://cloud.google.com/kubernetes-engine
75//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
76
77use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
// Value sent in the `metadata-flavor` request header.
const METADATA_FLAVOR_VALUE: &str = "Google";
// Name of the header attached to every request made by this module.
const METADATA_FLAVOR: &str = "metadata-flavor";
// Base URL used when neither the environment nor the builder overrides it.
const METADATA_ROOT: &str = "http://metadata.google.internal";
// Path of the default service account, appended to the endpoint. The token
// and identity requests append `/token` and `/identity` respectively.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
// Environment variable holding an alternate metadata host. Only the host is
// read from it; the builders prepend `http://`.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
// Detailed error message used when the credentials were created through ADC
// and the endpoint was not overridden.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
110
111#[derive(Debug)]
112struct MDSCredentials<T>
113where
114    T: CachedTokenProvider,
115{
116    quota_project_id: Option<String>,
117    token_provider: T,
118}
119
/// Creates [Credentials] instances backed by the [Metadata Service].
///
/// While the Google Cloud client libraries for Rust default to credentials
/// backed by the metadata service, some applications may need to:
/// * Customize the metadata service credentials in some way
/// * Bypass the [Application Default Credentials] lookup and only
///   use the metadata server credentials
/// * Use the credentials directly outside the client libraries
///
/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
#[derive(Debug, Default)]
pub struct Builder {
    // Base URL set through `with_endpoint`. It is only consulted when the
    // `GCE_METADATA_HOST` environment variable is not set.
    endpoint: Option<String>,
    // Quota project handed to the header construction, if any.
    quota_project_id: Option<String>,
    // Scopes to request; `None` sends no `scopes` query parameter.
    scopes: Option<Vec<String>>,
    // True only when the builder was created through `from_adc`; together
    // with the endpoint override it selects the error message.
    created_by_adc: bool,
    // Collects the retry policy, backoff policy and throttler, and wraps the
    // token provider when the credentials are built.
    retry_builder: RetryTokenProviderBuilder,
}
139
140impl Builder {
141    /// Sets the endpoint for this credentials.
142    ///
143    /// A trailing slash is significant, so specify the base URL without a trailing  
144    /// slash. If not set, the credentials use `http://metadata.google.internal`.
145    ///
146    /// # Example
147    /// ```
148    /// # use google_cloud_auth::credentials::mds::Builder;
149    /// # tokio_test::block_on(async {
150    /// let credentials = Builder::default()
151    ///     .with_endpoint("https://metadata.google.foobar")
152    ///     .build();
153    /// # });
154    /// ```
155    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
156        self.endpoint = Some(endpoint.into());
157        self
158    }
159
160    /// Set the [quota project] for this credentials.
161    ///
162    /// In some services, you can use a service account in
163    /// one project for authentication and authorization, and charge
164    /// the usage to a different project. This may require that the
165    /// service account has `serviceusage.services.use` permissions on the quota project.
166    ///
167    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
168    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
169        self.quota_project_id = Some(quota_project_id.into());
170        self
171    }
172
173    /// Sets the [scopes] for this credentials.
174    ///
175    /// Metadata server issues tokens based on the requested scopes.
176    /// If no scopes are specified, the credentials defaults to all
177    /// scopes configured for the [default service account] on the instance.
178    ///
179    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
180    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
181    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
182    where
183        I: IntoIterator<Item = S>,
184        S: Into<String>,
185    {
186        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
187        self
188    }
189
190    /// Configure the retry policy for fetching tokens.
191    ///
192    /// The retry policy controls how to handle retries, and sets limits on
193    /// the number of attempts or the total time spent retrying.
194    ///
195    /// ```
196    /// # use google_cloud_auth::credentials::mds::Builder;
197    /// # tokio_test::block_on(async {
198    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
199    /// let credentials = Builder::default()
200    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
201    ///     .build();
202    /// # });
203    /// ```
204    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
205        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
206        self
207    }
208
209    /// Configure the retry backoff policy.
210    ///
211    /// The backoff policy controls how long to wait in between retry attempts.
212    ///
213    /// ```
214    /// # use google_cloud_auth::credentials::mds::Builder;
215    /// # use std::time::Duration;
216    /// # tokio_test::block_on(async {
217    /// use gax::exponential_backoff::ExponentialBackoff;
218    /// let policy = ExponentialBackoff::default();
219    /// let credentials = Builder::default()
220    ///     .with_backoff_policy(policy)
221    ///     .build();
222    /// # });
223    /// ```
224    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
225        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
226        self
227    }
228
229    /// Configure the retry throttler.
230    ///
231    /// Advanced applications may want to configure a retry throttler to
232    /// [Address Cascading Failures] and when [Handling Overload] conditions.
233    /// The authentication library throttles its retry loop, using a policy to
234    /// control the throttling algorithm. Use this method to fine tune or
235    /// customize the default retry throttler.
236    ///
237    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
238    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
239    ///
240    /// ```
241    /// # use google_cloud_auth::credentials::mds::Builder;
242    /// # tokio_test::block_on(async {
243    /// use gax::retry_throttler::AdaptiveThrottler;
244    /// let credentials = Builder::default()
245    ///     .with_retry_throttler(AdaptiveThrottler::default())
246    ///     .build();
247    /// # });
248    /// ```
249    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
250        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
251        self
252    }
253
254    // This method is used to build mds credentials from ADC
255    pub(crate) fn from_adc() -> Self {
256        Self {
257            created_by_adc: true,
258            ..Default::default()
259        }
260    }
261
262    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
263        let final_endpoint: String;
264        let endpoint_overridden: bool;
265
266        // Determine the endpoint and whether it was overridden
267        if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
268            // Check GCE_METADATA_HOST environment variable first
269            final_endpoint = format!("http://{host_from_env}");
270            endpoint_overridden = true;
271        } else if let Some(builder_endpoint) = self.endpoint {
272            // Else, check if an endpoint was provided to the mds::Builder
273            final_endpoint = builder_endpoint;
274            endpoint_overridden = true;
275        } else {
276            // Else, use the default metadata root
277            final_endpoint = METADATA_ROOT.to_string();
278            endpoint_overridden = false;
279        };
280
281        let tp = MDSAccessTokenProvider::builder()
282            .endpoint(final_endpoint)
283            .maybe_scopes(self.scopes)
284            .endpoint_overridden(endpoint_overridden)
285            .created_by_adc(self.created_by_adc)
286            .build();
287        self.retry_builder.build(tp)
288    }
289
290    /// Returns a [Credentials] instance with the configured settings.
291    pub fn build(self) -> BuildResult<Credentials> {
292        let mdsc = MDSCredentials {
293            quota_project_id: self.quota_project_id.clone(),
294            token_provider: TokenCache::new(self.build_token_provider()),
295        };
296        Ok(Credentials {
297            inner: Arc::new(mdsc),
298        })
299    }
300}
301
302#[async_trait::async_trait]
303impl<T> CredentialsProvider for MDSCredentials<T>
304where
305    T: CachedTokenProvider,
306{
307    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
308        let cached_token = self.token_provider.token(extensions).await?;
309        build_cacheable_headers(&cached_token, &self.quota_project_id)
310    }
311}
312
// JSON body decoded from a successful token response.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    // The token value, copied into `Token::token`.
    access_token: String,
    // Token lifetime in seconds, measured from the moment the response is
    // processed. Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    // The token type, copied into `Token::token_type`.
    token_type: String,
}
320
// Fetches access tokens from the metadata service endpoint.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    // Scopes sent as a comma-separated `scopes` query parameter; `None`
    // sends no parameter.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    // Base URL; the service account path and `/token` are appended to it.
    #[builder(into)]
    endpoint: String,
    // True when the endpoint came from the environment or the builder
    // rather than the default metadata root.
    endpoint_overridden: bool,
    // True when the credentials were created through ADC.
    created_by_adc: bool,
}
330
331impl MDSAccessTokenProvider {
332    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
333    // environment variable is not set, we default to MDS credentials without checking if the code is really
334    // running in an environment with MDS. To help users who got to this state because of lack of credentials
335    // setup on their machines, we provide a detailed error message to them talking about local setup and other
336    // auth mechanisms available to them.
337    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
338    // error message because they deliberately wanted to use an MDS.
339    fn error_message(&self) -> &str {
340        if self.use_adc_message() {
341            MDS_NOT_FOUND_ERROR
342        } else {
343            "failed to fetch token"
344        }
345    }
346
347    fn use_adc_message(&self) -> bool {
348        self.created_by_adc && !self.endpoint_overridden
349    }
350}
351
352#[async_trait]
353impl TokenProvider for MDSAccessTokenProvider {
354    async fn token(&self) -> Result<Token> {
355        let client = Client::new();
356        let request = client
357            .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
358            .header(
359                METADATA_FLAVOR,
360                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
361            );
362        // Use the `scopes` option if set, otherwise let the MDS use the default
363        // scopes.
364        let scopes = self.scopes.as_ref().map(|v| v.join(","));
365        let request = scopes
366            .into_iter()
367            .fold(request, |r, s| r.query(&[("scopes", s)]));
368
369        // If the connection to MDS was not successful, it is useful to retry when really
370        // running on MDS environments and not useful if there is no MDS. We will mark the error
371        // as retryable and let the retry policy determine whether to retry or not. Whenever we
372        // define a default retry policy, we can skip retrying this case.
373        let response = request
374            .send()
375            .await
376            .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
377        // Process the response
378        if !response.status().is_success() {
379            let err = crate::errors::from_http_response(response, self.error_message()).await;
380            return Err(err);
381        }
382        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
383            // Decoding errors are not transient. Typically they indicate a badly
384            // configured MDS endpoint, or DNS redirecting the request to a random
385            // server, e.g., ISPs that redirect unknown services to HTTP.
386            CredentialsError::from_source(!e.is_decode(), e)
387        })?;
388        let token = Token {
389            token: response.access_token,
390            token_type: response.token_type,
391            expires_at: response
392                .expires_in
393                .map(|d| Instant::now() + Duration::from_secs(d)),
394            metadata: None,
395        };
396        Ok(token)
397    }
398}
399
400#[cfg(google_cloud_unstable_id_token)]
401pub mod idtoken {
402    //! Types for fetching ID tokens from the metadata service.
403    use super::{
404        GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_FLAVOR, METADATA_FLAVOR_VALUE,
405        METADATA_ROOT,
406    };
407    use crate::Result;
408    use crate::credentials::CacheableResource;
409    use crate::errors::CredentialsError;
410    use crate::token::{CachedTokenProvider, Token, TokenProvider};
411    use crate::token_cache::TokenCache;
412    use crate::{
413        BuildResult,
414        credentials::idtoken::dynamic::IDTokenCredentialsProvider,
415        credentials::idtoken::{IDTokenCredentials, parse_id_token_from_str},
416    };
417    use async_trait::async_trait;
418    use http::{Extensions, HeaderValue};
419    use reqwest::Client;
420    use std::sync::Arc;
421
422    #[derive(Debug)]
423    pub(crate) struct MDSCredentials<T>
424    where
425        T: CachedTokenProvider,
426    {
427        token_provider: T,
428    }
429
430    #[async_trait]
431    impl<T> IDTokenCredentialsProvider for MDSCredentials<T>
432    where
433        T: CachedTokenProvider,
434    {
435        async fn id_token(&self) -> Result<String> {
436            let cached_token = self.token_provider.token(Extensions::new()).await?;
437            match cached_token {
438                CacheableResource::New { data, .. } => Ok(data.token),
439                CacheableResource::NotModified => {
440                    Err(CredentialsError::from_msg(false, "failed to fetch token"))
441                }
442            }
443        }
444    }
445
    /// Creates [`IDTokenCredentials`] instances that fetch ID tokens from the
    /// metadata service.
    #[derive(Debug, Default)]
    pub struct Builder {
        // Base URL set through `with_endpoint`. It is only consulted when the
        // `GCE_METADATA_HOST` environment variable is not set.
        endpoint: Option<String>,
        // Value for the `format` query parameter; not sent when `None`.
        format: Option<String>,
        // Value for the `licenses` query parameter ("TRUE" or "FALSE");
        // not sent when `None`.
        licenses: Option<String>,
        // Value for the `audience` query parameter; always sent.
        target_audience: String,
    }
455
456    impl Builder {
457        /// Creates a new `Builder`.
458        ///
459        /// The `target_audience` is a required parameter that specifies the
460        /// intended audience of the ID token. This is typically the URL of the
461        /// service that will be receiving the token.
462        pub fn new<S: Into<String>>(target_audience: S) -> Self {
463            Builder {
464                format: None,
465                endpoint: None,
466                licenses: None,
467                target_audience: target_audience.into(),
468            }
469        }
470
471        /// Sets the endpoint for this credentials.
472        ///
473        /// A trailing slash is significant, so specify the base URL without a trailing  
474        /// slash. If not set, the credentials use `http://metadata.google.internal`.
475        pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
476            self.endpoint = Some(endpoint.into());
477            self
478        }
479
480        /// Sets the [format] of the token.
481        ///
482        /// Specifies whether or not the project and instance details are included in the payload.
483        /// Specify `full` to include this information in the payload or `standard` to omit the information
484        /// from the payload. The default value is `standard`.
485        ///
486        /// [format]: https://cloud.google.com/compute/docs/instances/verifying-instance-identity#token_format
487        pub fn with_format<S: Into<String>>(mut self, format: S) -> Self {
488            self.format = Some(format.into());
489            self
490        }
491
492        /// Whether to include the [license codes] of the instance in the token.
493        ///
494        /// Specify `true` to include this information or `false` to omit this information from the payload.
495        /// The default value is `false`. Has no effect unless format is `full`.
496        ///
497        /// [license codes]: https://cloud.google.com/compute/docs/reference/rest/v1/images/get#body.Image.FIELDS.license_code
498        pub fn with_licenses(mut self, licenses: bool) -> Self {
499            self.licenses = if licenses {
500                Some("TRUE".to_string())
501            } else {
502                Some("FALSE".to_string())
503            };
504            self
505        }
506
507        fn build_token_provider(self) -> MDSTokenProvider {
508            let final_endpoint: String;
509
510            // Determine the endpoint and whether it was overridden
511            if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
512                // Check GCE_METADATA_HOST environment variable first
513                final_endpoint = format!("http://{host_from_env}");
514            } else if let Some(builder_endpoint) = self.endpoint {
515                // Else, check if an endpoint was provided to the mds::Builder
516                final_endpoint = builder_endpoint;
517            } else {
518                // Else, use the default metadata root
519                final_endpoint = METADATA_ROOT.to_string();
520            };
521
522            MDSTokenProvider {
523                format: self.format,
524                licenses: self.licenses,
525                endpoint: final_endpoint,
526                target_audience: self.target_audience,
527            }
528        }
529
530        /// Returns an [`IDTokenCredentials`] instance with the configured
531        /// settings.
532        pub fn build(self) -> BuildResult<IDTokenCredentials> {
533            let creds = MDSCredentials {
534                token_provider: TokenCache::new(self.build_token_provider()),
535            };
536            Ok(IDTokenCredentials {
537                inner: Arc::new(creds),
538            })
539        }
540    }
541
    // Fetches ID tokens from the metadata service identity endpoint.
    #[derive(Debug, Clone, Default)]
    struct MDSTokenProvider {
        // Base URL; the service account path and `/identity` are appended.
        endpoint: String,
        // Optional `format` query parameter.
        format: Option<String>,
        // Optional `licenses` query parameter.
        licenses: Option<String>,
        // Sent as the `audience` query parameter on every request.
        target_audience: String,
    }
549
550    #[async_trait]
551    impl TokenProvider for MDSTokenProvider {
552        async fn token(&self) -> Result<Token> {
553            let client = Client::new();
554            let audience = self.target_audience.clone();
555            let request = client
556                .get(format!("{}{}/identity", self.endpoint, MDS_DEFAULT_URI))
557                .header(
558                    METADATA_FLAVOR,
559                    HeaderValue::from_static(METADATA_FLAVOR_VALUE),
560                )
561                .query(&[("audience", audience)]);
562            let request = self.format.iter().fold(request, |builder, format| {
563                builder.query(&[("format", format)])
564            });
565            let request = self.licenses.iter().fold(request, |builder, licenses| {
566                builder.query(&[("licenses", licenses)])
567            });
568
569            let response = request
570                .send()
571                .await
572                .map_err(|e| crate::errors::from_http_error(e, "failed to fetch token"))?;
573
574            if !response.status().is_success() {
575                let err =
576                    crate::errors::from_http_response(response, "failed to fetch token").await;
577                return Err(err);
578            }
579
580            let token = response
581                .text()
582                .await
583                .map_err(|e| CredentialsError::from_source(!e.is_decode(), e))?;
584
585            parse_id_token_from_str(token)
586        }
587    }
588}
589
590#[cfg(test)]
591mod tests {
592    use super::*;
593    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
594    use crate::credentials::QUOTA_PROJECT_KEY;
595    use crate::credentials::tests::{
596        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
597        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
598        get_token_type_from_headers,
599    };
600    use crate::errors;
601    use crate::errors::CredentialsError;
602    use crate::token::tests::MockTokenProvider;
603    use http::HeaderValue;
604    use http::header::AUTHORIZATION;
605    use httptest::cycle;
606    use httptest::matchers::{all_of, contains, request, url_decoded};
607    use httptest::responders::{json_encoded, status_code};
608    use httptest::{Expectation, Server};
609    use reqwest::StatusCode;
610    use scoped_env::ScopedEnv;
611    use serial_test::{parallel, serial};
612    use std::error::Error;
613    use test_case::test_case;
614    use url::Url;
615
616    type TestResult = anyhow::Result<()>;
617
618    #[tokio::test]
619    #[parallel]
620    async fn test_mds_retries_on_transient_failures() -> TestResult {
621        let mut server = Server::run();
622        server.expect(
623            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
624                .times(3)
625                .respond_with(status_code(503)),
626        );
627
628        let provider = Builder::default()
629            .with_endpoint(format!("http://{}", server.addr()))
630            .with_retry_policy(get_mock_auth_retry_policy(3))
631            .with_backoff_policy(get_mock_backoff_policy())
632            .with_retry_throttler(get_mock_retry_throttler())
633            .build_token_provider();
634
635        let err = provider.token().await.unwrap_err();
636        assert!(!err.is_transient());
637        server.verify_and_clear();
638        Ok(())
639    }
640
641    #[tokio::test]
642    #[parallel]
643    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
644        let mut server = Server::run();
645        server.expect(
646            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
647                .times(1)
648                .respond_with(status_code(401)),
649        );
650
651        let provider = Builder::default()
652            .with_endpoint(format!("http://{}", server.addr()))
653            .with_retry_policy(get_mock_auth_retry_policy(1))
654            .with_backoff_policy(get_mock_backoff_policy())
655            .with_retry_throttler(get_mock_retry_throttler())
656            .build_token_provider();
657
658        let err = provider.token().await.unwrap_err();
659        assert!(!err.is_transient());
660        server.verify_and_clear();
661        Ok(())
662    }
663
664    #[tokio::test]
665    #[parallel]
666    async fn test_mds_retries_for_success() -> TestResult {
667        let mut server = Server::run();
668        let response = MDSTokenResponse {
669            access_token: "test-access-token".to_string(),
670            expires_in: Some(3600),
671            token_type: "test-token-type".to_string(),
672        };
673
674        server.expect(
675            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
676                .times(3)
677                .respond_with(cycle![
678                    status_code(503).body("try-again"),
679                    status_code(503).body("try-again"),
680                    status_code(200)
681                        .append_header("Content-Type", "application/json")
682                        .body(serde_json::to_string(&response).unwrap()),
683                ]),
684        );
685
686        let provider = Builder::default()
687            .with_endpoint(format!("http://{}", server.addr()))
688            .with_retry_policy(get_mock_auth_retry_policy(3))
689            .with_backoff_policy(get_mock_backoff_policy())
690            .with_retry_throttler(get_mock_retry_throttler())
691            .build_token_provider();
692
693        let token = provider.token().await?;
694        assert_eq!(token.token, "test-access-token");
695
696        server.verify_and_clear();
697        Ok(())
698    }
699
700    #[test]
701    fn validate_default_endpoint_urls() {
702        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
703        assert!(default_endpoint_address.is_ok());
704
705        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
706        assert!(token_endpoint_address.is_ok());
707    }
708
709    #[tokio::test]
710    async fn headers_success() -> TestResult {
711        let token = Token {
712            token: "test-token".to_string(),
713            token_type: "Bearer".to_string(),
714            expires_at: None,
715            metadata: None,
716        };
717
718        let mut mock = MockTokenProvider::new();
719        mock.expect_token().times(1).return_once(|| Ok(token));
720
721        let mdsc = MDSCredentials {
722            quota_project_id: None,
723            token_provider: TokenCache::new(mock),
724        };
725
726        let mut extensions = Extensions::new();
727        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
728        let (headers, entity_tag) = match cached_headers {
729            CacheableResource::New { entity_tag, data } => (data, entity_tag),
730            CacheableResource::NotModified => unreachable!("expecting new headers"),
731        };
732        let token = headers.get(AUTHORIZATION).unwrap();
733        assert_eq!(headers.len(), 1, "{headers:?}");
734        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
735        assert!(token.is_sensitive());
736
737        extensions.insert(entity_tag);
738
739        let cached_headers = mdsc.headers(extensions).await?;
740
741        match cached_headers {
742            CacheableResource::New { .. } => unreachable!("expecting new headers"),
743            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
744        };
745        Ok(())
746    }
747
748    #[tokio::test]
749    async fn headers_failure() {
750        let mut mock = MockTokenProvider::new();
751        mock.expect_token()
752            .times(1)
753            .return_once(|| Err(errors::non_retryable_from_str("fail")));
754
755        let mdsc = MDSCredentials {
756            quota_project_id: None,
757            token_provider: TokenCache::new(mock),
758        };
759        assert!(mdsc.headers(Extensions::new()).await.is_err());
760    }
761
762    #[test]
763    fn error_message_with_adc() {
764        let provider = MDSAccessTokenProvider::builder()
765            .endpoint("http://127.0.0.1")
766            .created_by_adc(true)
767            .endpoint_overridden(false)
768            .build();
769
770        let want = MDS_NOT_FOUND_ERROR;
771        let got = provider.error_message();
772        assert!(got.contains(want), "{got}, {provider:?}");
773    }
774
775    #[test_case(false, false)]
776    #[test_case(false, true)]
777    #[test_case(true, true)]
778    fn error_message_without_adc(adc: bool, overridden: bool) {
779        let provider = MDSAccessTokenProvider::builder()
780            .endpoint("http://127.0.0.1")
781            .created_by_adc(adc)
782            .endpoint_overridden(overridden)
783            .build();
784
785        let not_want = MDS_NOT_FOUND_ERROR;
786        let got = provider.error_message();
787        assert!(!got.contains(not_want), "{got}, {provider:?}");
788    }
789
790    #[tokio::test]
791    #[serial]
792    async fn adc_no_mds() -> TestResult {
793        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
794            // The environment has an MDS, skip the test.
795            return Ok(());
796        };
797
798        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
799        assert!(
800            original_err.to_string().contains("application-default"),
801            "display={err}, debug={err:?}"
802        );
803
804        Ok(())
805    }
806
807    #[tokio::test]
808    #[serial]
809    async fn adc_overridden_mds() -> TestResult {
810        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
811
812        let err = Builder::from_adc()
813            .build_token_provider()
814            .token()
815            .await
816            .unwrap_err();
817
818        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
819
820        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
821        assert!(original_err.is_transient());
822        assert!(
823            !original_err.to_string().contains("application-default"),
824            "display={err}, debug={err:?}"
825        );
826        let source = find_source_error::<reqwest::Error>(&err);
827        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
828
829        Ok(())
830    }
831
832    #[tokio::test]
833    #[serial]
834    async fn builder_no_mds() -> TestResult {
835        let Err(e) = Builder::default().build_token_provider().token().await else {
836            // The environment has an MDS, skip the test.
837            return Ok(());
838        };
839
840        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
841        assert!(
842            !format!("{:?}", original_err.source()).contains("application-default"),
843            "{e:?}"
844        );
845
846        Ok(())
847    }
848
849    #[tokio::test]
850    #[serial]
851    async fn test_gce_metadata_host_env_var() -> TestResult {
852        let server = Server::run();
853        let scopes = ["scope1", "scope2"];
854        let response = MDSTokenResponse {
855            access_token: "test-access-token".to_string(),
856            expires_in: Some(3600),
857            token_type: "test-token-type".to_string(),
858        };
859        server.expect(
860            Expectation::matching(all_of![
861                request::path(format!("{MDS_DEFAULT_URI}/token")),
862                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
863            ])
864            .respond_with(json_encoded(response)),
865        );
866
867        let addr = server.addr().to_string();
868        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
869        let mdsc = Builder::default()
870            .with_scopes(["scope1", "scope2"])
871            .build()
872            .unwrap();
873        let headers = mdsc.headers(Extensions::new()).await.unwrap();
874        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
875
876        assert_eq!(
877            get_token_from_headers(headers).unwrap(),
878            "test-access-token"
879        );
880        Ok(())
881    }
882
883    #[tokio::test]
884    #[parallel]
885    async fn headers_success_with_quota_project() -> TestResult {
886        let server = Server::run();
887        let scopes = ["scope1", "scope2"];
888        let response = MDSTokenResponse {
889            access_token: "test-access-token".to_string(),
890            expires_in: Some(3600),
891            token_type: "test-token-type".to_string(),
892        };
893        server.expect(
894            Expectation::matching(all_of![
895                request::path(format!("{MDS_DEFAULT_URI}/token")),
896                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
897            ])
898            .respond_with(json_encoded(response)),
899        );
900
901        let mdsc = Builder::default()
902            .with_scopes(["scope1", "scope2"])
903            .with_endpoint(format!("http://{}", server.addr()))
904            .with_quota_project_id("test-project")
905            .build()?;
906
907        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
908        let token = headers.get(AUTHORIZATION).unwrap();
909        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
910
911        assert_eq!(headers.len(), 2, "{headers:?}");
912        assert_eq!(
913            token,
914            HeaderValue::from_static("test-token-type test-access-token")
915        );
916        assert!(token.is_sensitive());
917        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
918        assert!(!quota_project.is_sensitive());
919
920        Ok(())
921    }
922
923    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
924    #[parallel]
925    async fn token_caching() -> TestResult {
926        let mut server = Server::run();
927        let scopes = vec!["scope1".to_string()];
928        let response = MDSTokenResponse {
929            access_token: "test-access-token".to_string(),
930            expires_in: Some(3600),
931            token_type: "test-token-type".to_string(),
932        };
933        server.expect(
934            Expectation::matching(all_of![
935                request::path(format!("{MDS_DEFAULT_URI}/token")),
936                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
937            ])
938            .times(1)
939            .respond_with(json_encoded(response)),
940        );
941
942        let mdsc = Builder::default()
943            .with_scopes(scopes)
944            .with_endpoint(format!("http://{}", server.addr()))
945            .build()?;
946        let headers = mdsc.headers(Extensions::new()).await?;
947        assert_eq!(
948            get_token_from_headers(headers).unwrap(),
949            "test-access-token"
950        );
951        let headers = mdsc.headers(Extensions::new()).await?;
952        assert_eq!(
953            get_token_from_headers(headers).unwrap(),
954            "test-access-token"
955        );
956
957        // validate that the inner token provider is called only once
958        server.verify_and_clear();
959
960        Ok(())
961    }
962
963    #[tokio::test(start_paused = true)]
964    #[parallel]
965    async fn token_provider_full() -> TestResult {
966        let server = Server::run();
967        let scopes = vec!["scope1".to_string()];
968        let response = MDSTokenResponse {
969            access_token: "test-access-token".to_string(),
970            expires_in: Some(3600),
971            token_type: "test-token-type".to_string(),
972        };
973        server.expect(
974            Expectation::matching(all_of![
975                request::path(format!("{MDS_DEFAULT_URI}/token")),
976                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
977            ])
978            .respond_with(json_encoded(response)),
979        );
980
981        let token = Builder::default()
982            .with_endpoint(format!("http://{}", server.addr()))
983            .with_scopes(scopes)
984            .build_token_provider()
985            .token()
986            .await?;
987
988        let now = tokio::time::Instant::now();
989        assert_eq!(token.token, "test-access-token");
990        assert_eq!(token.token_type, "test-token-type");
991        assert!(
992            token
993                .expires_at
994                .is_some_and(|d| d >= now + Duration::from_secs(3600))
995        );
996
997        Ok(())
998    }
999
1000    #[tokio::test(start_paused = true)]
1001    #[parallel]
1002    async fn token_provider_full_no_scopes() -> TestResult {
1003        let server = Server::run();
1004        let response = MDSTokenResponse {
1005            access_token: "test-access-token".to_string(),
1006            expires_in: Some(3600),
1007            token_type: "test-token-type".to_string(),
1008        };
1009        server.expect(
1010            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
1011                .respond_with(json_encoded(response)),
1012        );
1013
1014        let token = Builder::default()
1015            .with_endpoint(format!("http://{}", server.addr()))
1016            .build_token_provider()
1017            .token()
1018            .await?;
1019
1020        let now = Instant::now();
1021        assert_eq!(token.token, "test-access-token");
1022        assert_eq!(token.token_type, "test-token-type");
1023        assert!(
1024            token
1025                .expires_at
1026                .is_some_and(|d| d == now + Duration::from_secs(3600))
1027        );
1028
1029        Ok(())
1030    }
1031
1032    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1033    #[parallel]
1034    async fn credential_provider_full() -> TestResult {
1035        let server = Server::run();
1036        let scopes = vec!["scope1".to_string()];
1037        let response = MDSTokenResponse {
1038            access_token: "test-access-token".to_string(),
1039            expires_in: None,
1040            token_type: "test-token-type".to_string(),
1041        };
1042        server.expect(
1043            Expectation::matching(all_of![
1044                request::path(format!("{MDS_DEFAULT_URI}/token")),
1045                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1046            ])
1047            .respond_with(json_encoded(response)),
1048        );
1049
1050        let mdsc = Builder::default()
1051            .with_endpoint(format!("http://{}", server.addr()))
1052            .with_scopes(scopes)
1053            .build()?;
1054        let headers = mdsc.headers(Extensions::new()).await?;
1055        assert_eq!(
1056            get_token_from_headers(headers.clone()).unwrap(),
1057            "test-access-token"
1058        );
1059        assert_eq!(
1060            get_token_type_from_headers(headers).unwrap(),
1061            "test-token-type"
1062        );
1063
1064        Ok(())
1065    }
1066
1067    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1068    #[parallel]
1069    async fn credentials_headers_retryable_error() -> TestResult {
1070        let server = Server::run();
1071        let scopes = vec!["scope1".to_string()];
1072        server.expect(
1073            Expectation::matching(all_of![
1074                request::path(format!("{MDS_DEFAULT_URI}/token")),
1075                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1076            ])
1077            .respond_with(status_code(503)),
1078        );
1079
1080        let mdsc = Builder::default()
1081            .with_endpoint(format!("http://{}", server.addr()))
1082            .with_scopes(scopes)
1083            .build()?;
1084        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1085        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1086        assert!(original_err.is_transient());
1087        let source = find_source_error::<reqwest::Error>(&err);
1088        assert!(
1089            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1090            "{err:?}"
1091        );
1092
1093        Ok(())
1094    }
1095
1096    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1097    #[parallel]
1098    async fn credentials_headers_nonretryable_error() -> TestResult {
1099        let server = Server::run();
1100        let scopes = vec!["scope1".to_string()];
1101        server.expect(
1102            Expectation::matching(all_of![
1103                request::path(format!("{MDS_DEFAULT_URI}/token")),
1104                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1105            ])
1106            .respond_with(status_code(401)),
1107        );
1108
1109        let mdsc = Builder::default()
1110            .with_endpoint(format!("http://{}", server.addr()))
1111            .with_scopes(scopes)
1112            .build()?;
1113
1114        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1115        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1116        assert!(!original_err.is_transient());
1117        let source = find_source_error::<reqwest::Error>(&err);
1118        assert!(
1119            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1120            "{err:?}"
1121        );
1122
1123        Ok(())
1124    }
1125
1126    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1127    #[parallel]
1128    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
1129        let server = Server::run();
1130        let scopes = vec!["scope1".to_string()];
1131        server.expect(
1132            Expectation::matching(all_of![
1133                request::path(format!("{MDS_DEFAULT_URI}/token")),
1134                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1135            ])
1136            .respond_with(json_encoded("bad json")),
1137        );
1138
1139        let mdsc = Builder::default()
1140            .with_endpoint(format!("http://{}", server.addr()))
1141            .with_scopes(scopes)
1142            .build()?;
1143
1144        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1145        assert!(!e.is_transient());
1146
1147        Ok(())
1148    }
1149
1150    #[tokio::test]
1151    async fn get_default_universe_domain_success() -> TestResult {
1152        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
1153        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
1154        Ok(())
1155    }
1156}
1157
/// Tests for the ID token builder. Compiled only for test builds in which the
/// `google_cloud_unstable_id_token` cfg is set.
#[cfg(all(test, google_cloud_unstable_id_token))]
mod unstable_tests {
    use super::idtoken;
    use super::*;
    use crate::credentials::idtoken::tests::generate_test_id_token;
    use crate::credentials::tests::find_source_error;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::status_code;
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};

    type TestResult = anyhow::Result<()>;

    /// Path of the identity endpoint the test servers expect requests on.
    fn identity_path() -> String {
        format!("{MDS_DEFAULT_URI}/identity")
    }

    /// `http://` URL pointing at the given test server.
    fn endpoint_of(server: &Server) -> String {
        format!("http://{}", server.addr())
    }

    /// Builds ID token credentials with a format and licenses enabled, and
    /// checks that the request carries the matching query parameters and that
    /// the served token is returned unchanged.
    #[tokio::test]
    #[parallel]
    async fn test_idtoken_builder_build() -> TestResult {
        let audience = "test-audience";
        let token_format = "format";
        let expected_token = generate_test_id_token(audience);

        let server = Server::run();
        let expectation = Expectation::matching(all_of![
            request::path(identity_path()),
            request::query(url_decoded(contains(("audience", audience)))),
            request::query(url_decoded(contains(("format", token_format)))),
            request::query(url_decoded(contains(("licenses", "TRUE"))))
        ])
        .respond_with(status_code(200).body(expected_token.clone()));
        server.expect(expectation);

        let credentials = idtoken::Builder::new(audience)
            .with_endpoint(endpoint_of(&server))
            .with_format(token_format)
            .with_licenses(true)
            .build()?;

        let received = credentials.id_token().await?;
        assert_eq!(received, expected_token);
        Ok(())
    }

    /// Sets `GCE_METADATA_HOST_ENV_VAR` to the address of a local test server
    /// and checks that credentials built without an explicit endpoint return
    /// the token served by that server.
    #[tokio::test]
    #[serial]
    async fn test_idtoken_builder_build_with_env_var() -> TestResult {
        let audience = "test-audience";
        let expected_token = generate_test_id_token(audience);

        let server = Server::run();
        let expectation = Expectation::matching(all_of![
            request::path(identity_path()),
            request::query(url_decoded(contains(("audience", audience))))
        ])
        .respond_with(status_code(200).body(expected_token.clone()));
        server.expect(expectation);

        let addr = server.addr().to_string();
        let _override = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);

        let credentials = idtoken::Builder::new(audience).build()?;

        let received = credentials.id_token().await?;
        assert_eq!(received, expected_token);

        let _cleared = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
        Ok(())
    }

    /// A 503 response from the identity endpoint must surface a
    /// `reqwest::Error` with status `SERVICE_UNAVAILABLE` in the error chain.
    #[tokio::test]
    #[parallel]
    async fn test_idtoken_provider_http_error() -> TestResult {
        let audience = "test-audience";

        let server = Server::run();
        let expectation = Expectation::matching(all_of![
            request::path(identity_path()),
            request::query(url_decoded(contains(("audience", audience))))
        ])
        .respond_with(status_code(503));
        server.expect(expectation);

        let credentials = idtoken::Builder::new(audience)
            .with_endpoint(endpoint_of(&server))
            .build()?;

        let err = credentials.id_token().await.unwrap_err();
        let http_err = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(http_err, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );
        Ok(())
    }

    /// Requests the ID token twice from the same credentials while the server
    /// expectation is limited to `times(1)`; both calls must return the
    /// served token.
    #[tokio::test]
    #[parallel]
    async fn test_idtoken_caching() -> TestResult {
        let audience = "test-audience";
        let expected_token = generate_test_id_token(audience);

        let server = Server::run();
        let expectation = Expectation::matching(all_of![
            request::path(identity_path()),
            request::query(url_decoded(contains(("audience", audience))))
        ])
        .times(1)
        .respond_with(status_code(200).body(expected_token.clone()));
        server.expect(expectation);

        let credentials = idtoken::Builder::new(audience)
            .with_endpoint(endpoint_of(&server))
            .build()?;

        for _ in 0..2 {
            let received = credentials.id_token().await?;
            assert_eq!(received, expected_token);
        }

        Ok(())
    }
}