Skip to main content

google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
//! You can use these access tokens to securely authenticate with Google Cloud,
//! without having to download secrets or other credentials. The types in this
//! module retrieve these access tokens, and they can be used with
//! the Google Cloud client libraries for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok::<(), anyhow::Error>(())
46//! # });
47//! ```
48//!
49//! ## Example: Creating credentials with custom retry behavior
50//!
51//! ```
52//! # use google_cloud_auth::credentials::mds::Builder;
53//! # use google_cloud_auth::credentials::Credentials;
54//! # use http::Extensions;
55//! # use std::time::Duration;
56//! # tokio_test::block_on(async {
57//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
58//! use gax::exponential_backoff::ExponentialBackoff;
59//! let backoff = ExponentialBackoff::default();
60//! let credentials: Credentials = Builder::default()
61//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
62//!     .with_backoff_policy(backoff)
63//!     .build()?;
64//! let headers = credentials.headers(Extensions::new()).await?;
65//! println!("Headers: {headers:?}");
66//! # Ok::<(), anyhow::Error>(())
67//! # });
68//! ```
69//!
70//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
71//! [Cloud Run]: https://cloud.google.com/run
72//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
73//! [gce-link]: https://cloud.google.com/products/compute
74//! [gke-link]: https://cloud.google.com/kubernetes-engine
75//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
76
77use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
78use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
79use crate::headers_util::build_cacheable_headers;
80use crate::mds::client::Client as MDSClient;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use gax::backoff_policy::BackoffPolicyArg;
87use gax::error::CredentialsError;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap};
91use std::default::Default;
92use std::sync::Arc;
93
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
/// Detailed error message used when credentials were created through the
/// Application Default Credentials lookup, the default metadata endpoint is in
/// use, and fetching a token fails. It points the user at local credential setup.
///
/// Each line below ends with `\`, so the newline and the next line's leading
/// whitespace are dropped, yielding a single-line message.
const MDS_NOT_FOUND_ERROR: &str = "Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud Environment \
    and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries";
102
103#[derive(Debug)]
104struct MDSCredentials<T>
105where
106    T: CachedTokenProvider,
107{
108    quota_project_id: Option<String>,
109    token_provider: T,
110}
111
112/// Creates [Credentials] instances backed by the [Metadata Service].
113///
114/// While the Google Cloud client libraries for Rust default to credentials
115/// backed by the metadata service, some applications may need to:
116/// * Customize the metadata service credentials in some way
117/// * Bypass the [Application Default Credentials] lookup and only
118///   use the metadata server credentials
119/// * Use the credentials directly outside the client libraries
120///
121/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
122/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
123#[derive(Debug, Default)]
124pub struct Builder {
125    endpoint: Option<String>,
126    quota_project_id: Option<String>,
127    scopes: Option<Vec<String>>,
128    created_by_adc: bool,
129    retry_builder: RetryTokenProviderBuilder,
130}
131
132impl Builder {
133    /// Sets the endpoint for this credentials.
134    ///
135    /// A trailing slash is significant, so specify the base URL without a trailing
136    /// slash. If not set, the credentials use `http://metadata.google.internal`.
137    ///
138    /// # Example
139    /// ```
140    /// # use google_cloud_auth::credentials::mds::Builder;
141    /// # tokio_test::block_on(async {
142    /// let credentials = Builder::default()
143    ///     .with_endpoint("https://metadata.google.foobar")
144    ///     .build();
145    /// # });
146    /// ```
147    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
148        self.endpoint = Some(endpoint.into());
149        self
150    }
151
152    /// Set the [quota project] for this credentials.
153    ///
154    /// In some services, you can use a service account in
155    /// one project for authentication and authorization, and charge
156    /// the usage to a different project. This may require that the
157    /// service account has `serviceusage.services.use` permissions on the quota project.
158    ///
159    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
160    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
161        self.quota_project_id = Some(quota_project_id.into());
162        self
163    }
164
165    /// Sets the [scopes] for this credentials.
166    ///
167    /// Metadata server issues tokens based on the requested scopes.
168    /// If no scopes are specified, the credentials defaults to all
169    /// scopes configured for the [default service account] on the instance.
170    ///
171    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
172    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
173    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
174    where
175        I: IntoIterator<Item = S>,
176        S: Into<String>,
177    {
178        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
179        self
180    }
181
182    /// Configure the retry policy for fetching tokens.
183    ///
184    /// The retry policy controls how to handle retries, and sets limits on
185    /// the number of attempts or the total time spent retrying.
186    ///
187    /// ```
188    /// # use google_cloud_auth::credentials::mds::Builder;
189    /// # tokio_test::block_on(async {
190    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
191    /// let credentials = Builder::default()
192    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
193    ///     .build();
194    /// # });
195    /// ```
196    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
197        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
198        self
199    }
200
201    /// Configure the retry backoff policy.
202    ///
203    /// The backoff policy controls how long to wait in between retry attempts.
204    ///
205    /// ```
206    /// # use google_cloud_auth::credentials::mds::Builder;
207    /// # use std::time::Duration;
208    /// # tokio_test::block_on(async {
209    /// use gax::exponential_backoff::ExponentialBackoff;
210    /// let policy = ExponentialBackoff::default();
211    /// let credentials = Builder::default()
212    ///     .with_backoff_policy(policy)
213    ///     .build();
214    /// # });
215    /// ```
216    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
217        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
218        self
219    }
220
221    /// Configure the retry throttler.
222    ///
223    /// Advanced applications may want to configure a retry throttler to
224    /// [Address Cascading Failures] and when [Handling Overload] conditions.
225    /// The authentication library throttles its retry loop, using a policy to
226    /// control the throttling algorithm. Use this method to fine tune or
227    /// customize the default retry throttler.
228    ///
229    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
230    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
231    ///
232    /// ```
233    /// # use google_cloud_auth::credentials::mds::Builder;
234    /// # tokio_test::block_on(async {
235    /// use gax::retry_throttler::AdaptiveThrottler;
236    /// let credentials = Builder::default()
237    ///     .with_retry_throttler(AdaptiveThrottler::default())
238    ///     .build();
239    /// # });
240    /// ```
241    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
242        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
243        self
244    }
245
246    // This method is used to build mds credentials from ADC
247    pub(crate) fn from_adc() -> Self {
248        Self {
249            created_by_adc: true,
250            ..Default::default()
251        }
252    }
253
254    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
255        let tp = MDSAccessTokenProvider::builder()
256            .endpoint(self.endpoint)
257            .maybe_scopes(self.scopes)
258            .created_by_adc(self.created_by_adc)
259            .build();
260        self.retry_builder.build(tp)
261    }
262
263    /// Returns a [Credentials] instance with the configured settings.
264    pub fn build(self) -> BuildResult<Credentials> {
265        Ok(self.build_access_token_credentials()?.into())
266    }
267
268    /// Returns an [AccessTokenCredentials] instance with the configured settings.
269    ///
270    /// # Example
271    ///
272    /// ```
273    /// # use google_cloud_auth::credentials::mds::Builder;
274    /// # use google_cloud_auth::credentials::{AccessTokenCredentials, AccessTokenCredentialsProvider};
275    /// # tokio_test::block_on(async {
276    /// let credentials: AccessTokenCredentials = Builder::default()
277    ///     .with_quota_project_id("my-quota-project")
278    ///     .build_access_token_credentials()?;
279    /// let access_token = credentials.access_token().await?;
280    /// println!("Token: {}", access_token.token);
281    /// # Ok::<(), anyhow::Error>(())
282    /// # });
283    /// ```
284    pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
285        let mdsc = MDSCredentials {
286            quota_project_id: self.quota_project_id.clone(),
287            token_provider: TokenCache::new(self.build_token_provider()),
288        };
289        Ok(AccessTokenCredentials {
290            inner: Arc::new(mdsc),
291        })
292    }
293
294    /// Returns a [crate::signer::Signer] instance with the configured settings.
295    ///
296    /// The returned [crate::signer::Signer] uses the [IAM signBlob API] to sign content. This API
297    /// requires a network request for each signing operation.
298    ///
299    /// # Example
300    ///
301    /// ```
302    /// # use google_cloud_auth::credentials::mds::Builder;
303    /// # use google_cloud_auth::signer::Signer;
304    /// # tokio_test::block_on(async {
305    /// let signer: Signer = Builder::default().build_signer()?;
306    /// # Ok::<(), anyhow::Error>(())
307    /// # });
308    /// ```
309    ///
310    /// [IAM signBlob API]: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/signBlob
311    pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
312        self.build_signer_with_iam_endpoint_override(None)
313    }
314
315    // only used for testing
316    fn build_signer_with_iam_endpoint_override(
317        self,
318        iam_endpoint: Option<String>,
319    ) -> BuildResult<crate::signer::Signer> {
320        let client = MDSClient::new(self.endpoint.clone());
321        let credentials = self.build()?;
322        let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
323        let signing_provider = iam_endpoint
324            .iter()
325            .fold(signing_provider, |signing_provider, endpoint| {
326                signing_provider.with_iam_endpoint_override(endpoint)
327            });
328        Ok(crate::signer::Signer {
329            inner: Arc::new(signing_provider),
330        })
331    }
332}
333
334#[async_trait::async_trait]
335impl<T> CredentialsProvider for MDSCredentials<T>
336where
337    T: CachedTokenProvider,
338{
339    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
340        let cached_token = self.token_provider.token(extensions).await?;
341        build_cacheable_headers(&cached_token, &self.quota_project_id)
342    }
343}
344
345#[async_trait::async_trait]
346impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
347where
348    T: CachedTokenProvider,
349{
350    async fn access_token(&self) -> Result<AccessToken> {
351        let token = self.token_provider.token(Extensions::new()).await?;
352        token.into()
353    }
354}
355
/// Internal builder for [MDSAccessTokenProvider].
#[derive(Debug, Default)]
struct MDSAccessTokenProviderBuilder {
    // Scopes to request; `None` lets the metadata server pick its defaults.
    scopes: Option<Vec<String>>,
    // Endpoint override handed to `MDSClient::new`.
    endpoint: Option<String>,
    // Whether the credentials originate from the ADC lookup.
    created_by_adc: bool,
}
362
363impl MDSAccessTokenProviderBuilder {
364    fn build(self) -> MDSAccessTokenProvider {
365        MDSAccessTokenProvider {
366            client: MDSClient::new(self.endpoint),
367            scopes: self.scopes,
368            created_by_adc: self.created_by_adc,
369        }
370    }
371
372    fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
373        self.scopes = v;
374        self
375    }
376
377    fn endpoint<T>(mut self, v: Option<T>) -> Self
378    where
379        T: Into<String>,
380    {
381        self.endpoint = v.map(Into::into);
382        self
383    }
384
385    fn created_by_adc(mut self, v: bool) -> Self {
386        self.created_by_adc = v;
387        self
388    }
389}
390
391#[derive(Debug, Clone)]
392struct MDSAccessTokenProvider {
393    scopes: Option<Vec<String>>,
394    client: MDSClient,
395    created_by_adc: bool,
396}
397
398impl MDSAccessTokenProvider {
399    fn builder() -> MDSAccessTokenProviderBuilder {
400        MDSAccessTokenProviderBuilder::default()
401    }
402
403    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
404    // environment variable is not set, we default to MDS credentials without checking if the code is really
405    // running in an environment with MDS. To help users who got to this state because of lack of credentials
406    // setup on their machines, we provide a detailed error message to them talking about local setup and other
407    // auth mechanisms available to them.
408    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
409    // error message because they deliberately wanted to use an MDS.
410    fn error_message(&self) -> &str {
411        if self.use_adc_message() {
412            MDS_NOT_FOUND_ERROR
413        } else {
414            "failed to fetch token"
415        }
416    }
417
418    fn use_adc_message(&self) -> bool {
419        self.created_by_adc && self.client.is_default_endpoint
420    }
421}
422
423#[async_trait]
424impl TokenProvider for MDSAccessTokenProvider {
425    async fn token(&self) -> Result<Token> {
426        self.client
427            .access_token(self.scopes.clone())
428            .await
429            .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
430    }
431}
432
433#[cfg(test)]
434mod tests {
435    use super::*;
436    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
437    use crate::credentials::QUOTA_PROJECT_KEY;
438    use crate::credentials::tests::{
439        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
440        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
441        get_token_type_from_headers,
442    };
443    use crate::errors;
444    use crate::errors::CredentialsError;
445    use crate::mds::client::MDSTokenResponse;
446    use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT};
447    use crate::token::tests::MockTokenProvider;
448    use http::HeaderValue;
449    use http::header::AUTHORIZATION;
450    use httptest::cycle;
451    use httptest::matchers::{all_of, contains, request, url_decoded};
452    use httptest::responders::{json_encoded, status_code};
453    use httptest::{Expectation, Server};
454    use reqwest::StatusCode;
455    use scoped_env::ScopedEnv;
456    use serial_test::{parallel, serial};
457    use std::error::Error;
458    use std::time::Duration;
459    use test_case::test_case;
460    use tokio::time::Instant;
461    use url::Url;
462
463    type TestResult = anyhow::Result<()>;
464
465    #[tokio::test]
466    #[parallel]
467    async fn test_mds_retries_on_transient_failures() -> TestResult {
468        let mut server = Server::run();
469        server.expect(
470            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
471                .times(3)
472                .respond_with(status_code(503)),
473        );
474
475        let provider = Builder::default()
476            .with_endpoint(format!("http://{}", server.addr()))
477            .with_retry_policy(get_mock_auth_retry_policy(3))
478            .with_backoff_policy(get_mock_backoff_policy())
479            .with_retry_throttler(get_mock_retry_throttler())
480            .build_token_provider();
481
482        let err = provider.token().await.unwrap_err();
483        assert!(!err.is_transient());
484        server.verify_and_clear();
485        Ok(())
486    }
487
488    #[tokio::test]
489    #[parallel]
490    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
491        let mut server = Server::run();
492        server.expect(
493            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
494                .times(1)
495                .respond_with(status_code(401)),
496        );
497
498        let provider = Builder::default()
499            .with_endpoint(format!("http://{}", server.addr()))
500            .with_retry_policy(get_mock_auth_retry_policy(1))
501            .with_backoff_policy(get_mock_backoff_policy())
502            .with_retry_throttler(get_mock_retry_throttler())
503            .build_token_provider();
504
505        let err = provider.token().await.unwrap_err();
506        assert!(!err.is_transient());
507        server.verify_and_clear();
508        Ok(())
509    }
510
511    #[tokio::test]
512    #[parallel]
513    async fn test_mds_retries_for_success() -> TestResult {
514        let mut server = Server::run();
515        let response = MDSTokenResponse {
516            access_token: "test-access-token".to_string(),
517            expires_in: Some(3600),
518            token_type: "test-token-type".to_string(),
519        };
520
521        server.expect(
522            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
523                .times(3)
524                .respond_with(cycle![
525                    status_code(503).body("try-again"),
526                    status_code(503).body("try-again"),
527                    status_code(200)
528                        .append_header("Content-Type", "application/json")
529                        .body(serde_json::to_string(&response).unwrap()),
530                ]),
531        );
532
533        let provider = Builder::default()
534            .with_endpoint(format!("http://{}", server.addr()))
535            .with_retry_policy(get_mock_auth_retry_policy(3))
536            .with_backoff_policy(get_mock_backoff_policy())
537            .with_retry_throttler(get_mock_retry_throttler())
538            .build_token_provider();
539
540        let token = provider.token().await?;
541        assert_eq!(token.token, "test-access-token");
542
543        server.verify_and_clear();
544        Ok(())
545    }
546
547    #[test]
548    #[parallel]
549    fn validate_default_endpoint_urls() {
550        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
551        assert!(default_endpoint_address.is_ok());
552
553        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
554        assert!(token_endpoint_address.is_ok());
555    }
556
557    #[tokio::test]
558    #[parallel]
559    async fn headers_success() -> TestResult {
560        let token = Token {
561            token: "test-token".to_string(),
562            token_type: "Bearer".to_string(),
563            expires_at: None,
564            metadata: None,
565        };
566
567        let mut mock = MockTokenProvider::new();
568        mock.expect_token().times(1).return_once(|| Ok(token));
569
570        let mdsc = MDSCredentials {
571            quota_project_id: None,
572            token_provider: TokenCache::new(mock),
573        };
574
575        let mut extensions = Extensions::new();
576        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
577        let (headers, entity_tag) = match cached_headers {
578            CacheableResource::New { entity_tag, data } => (data, entity_tag),
579            CacheableResource::NotModified => unreachable!("expecting new headers"),
580        };
581        let token = headers.get(AUTHORIZATION).unwrap();
582        assert_eq!(headers.len(), 1, "{headers:?}");
583        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
584        assert!(token.is_sensitive());
585
586        extensions.insert(entity_tag);
587
588        let cached_headers = mdsc.headers(extensions).await?;
589
590        match cached_headers {
591            CacheableResource::New { .. } => unreachable!("expecting new headers"),
592            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
593        };
594        Ok(())
595    }
596
597    #[tokio::test]
598    #[parallel]
599    async fn access_token_success() -> TestResult {
600        let server = Server::run();
601        let response = MDSTokenResponse {
602            access_token: "test-access-token".to_string(),
603            expires_in: Some(3600),
604            token_type: "Bearer".to_string(),
605        };
606        server.expect(
607            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
608                .respond_with(json_encoded(response)),
609        );
610
611        let creds = Builder::default()
612            .with_endpoint(format!("http://{}", server.addr()))
613            .build_access_token_credentials()
614            .unwrap();
615
616        let access_token = creds.access_token().await.unwrap();
617        assert_eq!(access_token.token, "test-access-token");
618
619        Ok(())
620    }
621
622    #[tokio::test]
623    #[parallel]
624    async fn headers_failure() {
625        let mut mock = MockTokenProvider::new();
626        mock.expect_token()
627            .times(1)
628            .return_once(|| Err(errors::non_retryable_from_str("fail")));
629
630        let mdsc = MDSCredentials {
631            quota_project_id: None,
632            token_provider: TokenCache::new(mock),
633        };
634        assert!(mdsc.headers(Extensions::new()).await.is_err());
635    }
636
637    #[test]
638    #[parallel]
639    fn error_message_with_adc() {
640        let provider = MDSAccessTokenProvider::builder()
641            .created_by_adc(true)
642            .build();
643
644        let want = MDS_NOT_FOUND_ERROR;
645        let got = provider.error_message();
646        assert!(got.contains(want), "{got}, {provider:?}");
647    }
648
649    #[test_case(false, false)]
650    #[test_case(false, true)]
651    #[test_case(true, true)]
652    fn error_message_without_adc(adc: bool, overridden: bool) {
653        let endpoint = if overridden {
654            Some("http://127.0.0.1")
655        } else {
656            None
657        };
658        let provider = MDSAccessTokenProvider::builder()
659            .endpoint(endpoint)
660            .created_by_adc(adc)
661            .build();
662
663        let not_want = MDS_NOT_FOUND_ERROR;
664        let got = provider.error_message();
665        assert!(!got.contains(not_want), "{got}, {provider:?}");
666    }
667
668    #[tokio::test]
669    #[serial]
670    async fn adc_no_mds() -> TestResult {
671        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
672            // The environment has an MDS, skip the test.
673            return Ok(());
674        };
675
676        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
677        assert!(
678            original_err.to_string().contains("application-default"),
679            "display={err}, debug={err:?}"
680        );
681
682        Ok(())
683    }
684
685    #[tokio::test]
686    #[serial]
687    async fn adc_overridden_mds() -> TestResult {
688        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
689
690        let err = Builder::from_adc()
691            .build_token_provider()
692            .token()
693            .await
694            .unwrap_err();
695
696        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);
697
698        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
699        assert!(original_err.is_transient());
700        assert!(
701            !original_err.to_string().contains("application-default"),
702            "display={err}, debug={err:?}"
703        );
704        let source = find_source_error::<reqwest::Error>(&err);
705        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
706
707        Ok(())
708    }
709
710    #[tokio::test]
711    #[serial]
712    async fn builder_no_mds() -> TestResult {
713        let Err(e) = Builder::default().build_token_provider().token().await else {
714            // The environment has an MDS, skip the test.
715            return Ok(());
716        };
717
718        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
719        assert!(
720            !format!("{:?}", original_err.source()).contains("application-default"),
721            "{e:?}"
722        );
723
724        Ok(())
725    }
726
727    #[tokio::test]
728    #[serial]
729    async fn test_gce_metadata_host_env_var() -> TestResult {
730        let server = Server::run();
731        let scopes = ["scope1", "scope2"];
732        let response = MDSTokenResponse {
733            access_token: "test-access-token".to_string(),
734            expires_in: Some(3600),
735            token_type: "test-token-type".to_string(),
736        };
737        server.expect(
738            Expectation::matching(all_of![
739                request::path(format!("{MDS_DEFAULT_URI}/token")),
740                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
741            ])
742            .respond_with(json_encoded(response)),
743        );
744
745        let addr = server.addr().to_string();
746        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
747        let mdsc = Builder::default()
748            .with_scopes(["scope1", "scope2"])
749            .build()
750            .unwrap();
751        let headers = mdsc.headers(Extensions::new()).await.unwrap();
752        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);
753
754        assert_eq!(
755            get_token_from_headers(headers).unwrap(),
756            "test-access-token"
757        );
758        Ok(())
759    }
760
761    #[tokio::test]
762    #[parallel]
763    async fn headers_success_with_quota_project() -> TestResult {
764        let server = Server::run();
765        let scopes = ["scope1", "scope2"];
766        let response = MDSTokenResponse {
767            access_token: "test-access-token".to_string(),
768            expires_in: Some(3600),
769            token_type: "test-token-type".to_string(),
770        };
771        server.expect(
772            Expectation::matching(all_of![
773                request::path(format!("{MDS_DEFAULT_URI}/token")),
774                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
775            ])
776            .respond_with(json_encoded(response)),
777        );
778
779        let mdsc = Builder::default()
780            .with_scopes(["scope1", "scope2"])
781            .with_endpoint(format!("http://{}", server.addr()))
782            .with_quota_project_id("test-project")
783            .build()?;
784
785        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
786        let token = headers.get(AUTHORIZATION).unwrap();
787        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
788
789        assert_eq!(headers.len(), 2, "{headers:?}");
790        assert_eq!(
791            token,
792            HeaderValue::from_static("test-token-type test-access-token")
793        );
794        assert!(token.is_sensitive());
795        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
796        assert!(!quota_project.is_sensitive());
797
798        Ok(())
799    }
800
801    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
802    #[parallel]
803    async fn token_caching() -> TestResult {
804        let mut server = Server::run();
805        let scopes = vec!["scope1".to_string()];
806        let response = MDSTokenResponse {
807            access_token: "test-access-token".to_string(),
808            expires_in: Some(3600),
809            token_type: "test-token-type".to_string(),
810        };
811        server.expect(
812            Expectation::matching(all_of![
813                request::path(format!("{MDS_DEFAULT_URI}/token")),
814                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
815            ])
816            .times(1)
817            .respond_with(json_encoded(response)),
818        );
819
820        let mdsc = Builder::default()
821            .with_scopes(scopes)
822            .with_endpoint(format!("http://{}", server.addr()))
823            .build()?;
824        let headers = mdsc.headers(Extensions::new()).await?;
825        assert_eq!(
826            get_token_from_headers(headers).unwrap(),
827            "test-access-token"
828        );
829        let headers = mdsc.headers(Extensions::new()).await?;
830        assert_eq!(
831            get_token_from_headers(headers).unwrap(),
832            "test-access-token"
833        );
834
835        // validate that the inner token provider is called only once
836        server.verify_and_clear();
837
838        Ok(())
839    }
840
841    #[tokio::test(start_paused = true)]
842    #[parallel]
843    async fn token_provider_full() -> TestResult {
844        let server = Server::run();
845        let scopes = vec!["scope1".to_string()];
846        let response = MDSTokenResponse {
847            access_token: "test-access-token".to_string(),
848            expires_in: Some(3600),
849            token_type: "test-token-type".to_string(),
850        };
851        server.expect(
852            Expectation::matching(all_of![
853                request::path(format!("{MDS_DEFAULT_URI}/token")),
854                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
855            ])
856            .respond_with(json_encoded(response)),
857        );
858
859        let token = Builder::default()
860            .with_endpoint(format!("http://{}", server.addr()))
861            .with_scopes(scopes)
862            .build_token_provider()
863            .token()
864            .await?;
865
866        let now = tokio::time::Instant::now();
867        assert_eq!(token.token, "test-access-token");
868        assert_eq!(token.token_type, "test-token-type");
869        assert!(
870            token
871                .expires_at
872                .is_some_and(|d| d >= now + Duration::from_secs(3600))
873        );
874
875        Ok(())
876    }
877
878    #[tokio::test(start_paused = true)]
879    #[parallel]
880    async fn token_provider_full_no_scopes() -> TestResult {
881        let server = Server::run();
882        let response = MDSTokenResponse {
883            access_token: "test-access-token".to_string(),
884            expires_in: Some(3600),
885            token_type: "test-token-type".to_string(),
886        };
887        server.expect(
888            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
889                .respond_with(json_encoded(response)),
890        );
891
892        let token = Builder::default()
893            .with_endpoint(format!("http://{}", server.addr()))
894            .build_token_provider()
895            .token()
896            .await?;
897
898        let now = Instant::now();
899        assert_eq!(token.token, "test-access-token");
900        assert_eq!(token.token_type, "test-token-type");
901        assert!(
902            token
903                .expires_at
904                .is_some_and(|d| d == now + Duration::from_secs(3600))
905        );
906
907        Ok(())
908    }
909
910    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
911    #[parallel]
912    async fn credential_provider_full() -> TestResult {
913        let server = Server::run();
914        let scopes = vec!["scope1".to_string()];
915        let response = MDSTokenResponse {
916            access_token: "test-access-token".to_string(),
917            expires_in: None,
918            token_type: "test-token-type".to_string(),
919        };
920        server.expect(
921            Expectation::matching(all_of![
922                request::path(format!("{MDS_DEFAULT_URI}/token")),
923                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
924            ])
925            .respond_with(json_encoded(response)),
926        );
927
928        let mdsc = Builder::default()
929            .with_endpoint(format!("http://{}", server.addr()))
930            .with_scopes(scopes)
931            .build()?;
932        let headers = mdsc.headers(Extensions::new()).await?;
933        assert_eq!(
934            get_token_from_headers(headers.clone()).unwrap(),
935            "test-access-token"
936        );
937        assert_eq!(
938            get_token_type_from_headers(headers).unwrap(),
939            "test-token-type"
940        );
941
942        Ok(())
943    }
944
945    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
946    #[parallel]
947    async fn credentials_headers_retryable_error() -> TestResult {
948        let server = Server::run();
949        let scopes = vec!["scope1".to_string()];
950        server.expect(
951            Expectation::matching(all_of![
952                request::path(format!("{MDS_DEFAULT_URI}/token")),
953                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
954            ])
955            .respond_with(status_code(503)),
956        );
957
958        let mdsc = Builder::default()
959            .with_endpoint(format!("http://{}", server.addr()))
960            .with_scopes(scopes)
961            .build()?;
962        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
963        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
964        assert!(original_err.is_transient());
965        let source = find_source_error::<reqwest::Error>(&err);
966        assert!(
967            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
968            "{err:?}"
969        );
970
971        Ok(())
972    }
973
974    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
975    #[parallel]
976    async fn credentials_headers_nonretryable_error() -> TestResult {
977        let server = Server::run();
978        let scopes = vec!["scope1".to_string()];
979        server.expect(
980            Expectation::matching(all_of![
981                request::path(format!("{MDS_DEFAULT_URI}/token")),
982                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
983            ])
984            .respond_with(status_code(401)),
985        );
986
987        let mdsc = Builder::default()
988            .with_endpoint(format!("http://{}", server.addr()))
989            .with_scopes(scopes)
990            .build()?;
991
992        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
993        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
994        assert!(!original_err.is_transient());
995        let source = find_source_error::<reqwest::Error>(&err);
996        assert!(
997            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
998            "{err:?}"
999        );
1000
1001        Ok(())
1002    }
1003
1004    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1005    #[parallel]
1006    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
1007        let server = Server::run();
1008        let scopes = vec!["scope1".to_string()];
1009        server.expect(
1010            Expectation::matching(all_of![
1011                request::path(format!("{MDS_DEFAULT_URI}/token")),
1012                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1013            ])
1014            .respond_with(json_encoded("bad json")),
1015        );
1016
1017        let mdsc = Builder::default()
1018            .with_endpoint(format!("http://{}", server.addr()))
1019            .with_scopes(scopes)
1020            .build()?;
1021
1022        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1023        assert!(!e.is_transient());
1024
1025        Ok(())
1026    }
1027
1028    #[tokio::test]
1029    #[parallel]
1030    async fn get_default_universe_domain_success() -> TestResult {
1031        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
1032        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
1033        Ok(())
1034    }
1035
1036    #[tokio::test]
1037    #[parallel]
1038    async fn get_mds_signer() -> TestResult {
1039        use base64::{Engine, prelude::BASE64_STANDARD};
1040        use serde_json::json;
1041
1042        let server = Server::run();
1043        server.expect(
1044            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
1045                .respond_with(json_encoded(MDSTokenResponse {
1046                    access_token: "test-access-token".to_string(),
1047                    expires_in: None,
1048                    token_type: "Bearer".to_string(),
1049                })),
1050        );
1051        server.expect(
1052            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
1053                .respond_with(status_code(200).body("test-client-email")),
1054        );
1055        server.expect(
1056            Expectation::matching(all_of![
1057                request::method_path(
1058                    "POST",
1059                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
1060                ),
1061                request::headers(contains(("authorization", "Bearer test-access-token"))),
1062            ])
1063            .respond_with(json_encoded(json!({
1064                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
1065            }))),
1066        );
1067
1068        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
1069
1070        let signer = Builder::default()
1071            .with_endpoint(&endpoint)
1072            .build_signer_with_iam_endpoint_override(Some(endpoint))?;
1073
1074        let client_email = signer.client_email().await?;
1075        assert_eq!(client_email, "test-client-email");
1076
1077        let signature = signer.sign(b"test").await?;
1078        assert_eq!(signature.as_ref(), b"signed_blob");
1079
1080        Ok(())
1081    }
1082}