Skip to main content

google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
//! You can use these access tokens to securely authenticate with Google Cloud
//! without downloading secrets or other credentials. The types in this module
//! retrieve the tokens, and can be used with the Google Cloud client libraries
//! for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # async fn sample() -> anyhow::Result<()> {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok(()) }
46//! ```
47//!
48//! ## Example: Creating credentials with custom retry behavior
49//!
50//! ```
51//! # use google_cloud_auth::credentials::mds::Builder;
52//! # use google_cloud_auth::credentials::Credentials;
53//! # use http::Extensions;
54//! # use std::time::Duration;
55//! # async fn sample() -> anyhow::Result<()> {
56//! use google_cloud_gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
57//! use google_cloud_gax::exponential_backoff::ExponentialBackoff;
58//! let backoff = ExponentialBackoff::default();
59//! let credentials: Credentials = Builder::default()
60//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
61//!     .with_backoff_policy(backoff)
62//!     .build()?;
63//! let headers = credentials.headers(Extensions::new()).await?;
64//! println!("Headers: {headers:?}");
65//! # Ok(()) }
66//! ```
67//!
68//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
69//! [Cloud Run]: https://cloud.google.com/run
70//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
71//! [gce-link]: https://cloud.google.com/products/compute
72//! [gke-link]: https://cloud.google.com/kubernetes-engine
73//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
74
75use crate::access_boundary::CredentialsWithAccessBoundary;
76use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
77use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
78use crate::headers_util::AuthHeadersBuilder;
79use crate::mds::client::Client as MDSClient;
80use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
81use crate::token::{CachedTokenProvider, Token, TokenProvider};
82use crate::token_cache::TokenCache;
83use crate::{BuildResult, Result};
84use async_trait::async_trait;
85use google_cloud_gax::backoff_policy::BackoffPolicyArg;
86use google_cloud_gax::error::CredentialsError;
87use google_cloud_gax::retry_policy::RetryPolicyArg;
88use google_cloud_gax::retry_throttler::RetryThrottlerArg;
89use http::{Extensions, HeaderMap};
90use std::default::Default;
91use std::sync::Arc;
92
93// TODO(#2235) - Improve this message by talking about retries when really running with MDS
94const MDS_NOT_FOUND_ERROR: &str = concat!(
95    "Could not fetch an auth token to authenticate with Google Cloud. ",
96    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
97    "and you have not configured local credentials for development and testing. ",
98    "To setup local credentials, run `gcloud auth application-default login`. ",
99    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
100);
101
102#[derive(Debug)]
103struct MDSCredentials<T>
104where
105    T: CachedTokenProvider,
106{
107    quota_project_id: Option<String>,
108    token_provider: T,
109}
110
111/// Creates [Credentials] instances backed by the [Metadata Service].
112///
113/// While the Google Cloud client libraries for Rust default to credentials
114/// backed by the metadata service, some applications may need to:
115/// * Customize the metadata service credentials in some way
116/// * Bypass the [Application Default Credentials] lookup and only
117///   use the metadata server credentials
118/// * Use the credentials directly outside the client libraries
119///
120/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
121/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
122#[derive(Debug)]
123pub struct Builder {
124    endpoint: Option<String>,
125    quota_project_id: Option<String>,
126    scopes: Option<Vec<String>>,
127    created_by_adc: bool,
128    retry_builder: RetryTokenProviderBuilder,
129    iam_endpoint_override: Option<String>,
130    is_access_boundary_enabled: bool,
131}
132
133impl Default for Builder {
134    fn default() -> Self {
135        Self {
136            endpoint: None,
137            quota_project_id: None,
138            scopes: None,
139            created_by_adc: false,
140            retry_builder: RetryTokenProviderBuilder::default(),
141            iam_endpoint_override: None,
142            is_access_boundary_enabled: true,
143        }
144    }
145}
146
147impl Builder {
148    /// Sets the endpoint for this credentials.
149    ///
150    /// A trailing slash is significant, so specify the base URL without a trailing
151    /// slash. If not set, the credentials use `http://metadata.google.internal`.
152    ///
153    /// # Example
154    /// ```
155    /// # use google_cloud_auth::credentials::mds::Builder;
156    /// # async fn sample() -> anyhow::Result<()> {
157    /// let credentials = Builder::default()
158    ///     .with_endpoint("https://metadata.google.foobar")
159    ///     .build()?;
160    /// # Ok(()) }
161    /// ```
162    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
163        self.endpoint = Some(endpoint.into());
164        self
165    }
166
167    /// Set the [quota project] for this credentials.
168    ///
169    /// In some services, you can use a service account in
170    /// one project for authentication and authorization, and charge
171    /// the usage to a different project. This may require that the
172    /// service account has `serviceusage.services.use` permissions on the quota project.
173    ///
174    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
175    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
176        self.quota_project_id = Some(quota_project_id.into());
177        self
178    }
179
180    /// Sets the [scopes] for this credentials.
181    ///
182    /// Metadata server issues tokens based on the requested scopes.
183    /// If no scopes are specified, the credentials defaults to all
184    /// scopes configured for the [default service account] on the instance.
185    ///
186    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
187    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
188    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
189    where
190        I: IntoIterator<Item = S>,
191        S: Into<String>,
192    {
193        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
194        self
195    }
196
197    /// Configure the retry policy for fetching tokens.
198    ///
199    /// The retry policy controls how to handle retries, and sets limits on
200    /// the number of attempts or the total time spent retrying.
201    ///
202    /// ```
203    /// # use google_cloud_auth::credentials::mds::Builder;
204    /// # async fn sample() -> anyhow::Result<()> {
205    /// use google_cloud_gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
206    /// let credentials = Builder::default()
207    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
208    ///     .build()?;
209    /// # Ok(()) }
210    /// ```
211    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
212        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
213        self
214    }
215
216    /// Configure the retry backoff policy.
217    ///
218    /// The backoff policy controls how long to wait in between retry attempts.
219    ///
220    /// ```
221    /// # use google_cloud_auth::credentials::mds::Builder;
222    /// # use std::time::Duration;
223    /// # async fn sample() -> anyhow::Result<()> {
224    /// use google_cloud_gax::exponential_backoff::ExponentialBackoff;
225    /// let policy = ExponentialBackoff::default();
226    /// let credentials = Builder::default()
227    ///     .with_backoff_policy(policy)
228    ///     .build()?;
229    /// # Ok(()) }
230    /// ```
231    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
232        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
233        self
234    }
235
236    /// Configure the retry throttler.
237    ///
238    /// Advanced applications may want to configure a retry throttler to
239    /// [Address Cascading Failures] and when [Handling Overload] conditions.
240    /// The authentication library throttles its retry loop, using a policy to
241    /// control the throttling algorithm. Use this method to fine tune or
242    /// customize the default retry throttler.
243    ///
244    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
245    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
246    ///
247    /// ```
248    /// # use google_cloud_auth::credentials::mds::Builder;
249    /// # async fn sample() -> anyhow::Result<()> {
250    /// use google_cloud_gax::retry_throttler::AdaptiveThrottler;
251    /// let credentials = Builder::default()
252    ///     .with_retry_throttler(AdaptiveThrottler::default())
253    ///     .build()?;
254    /// # Ok(()) }
255    /// ```
256    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
257        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
258        self
259    }
260
261    #[cfg(test)]
262    fn maybe_iam_endpoint_override(mut self, iam_endpoint_override: Option<String>) -> Self {
263        self.iam_endpoint_override = iam_endpoint_override;
264        self
265    }
266
267    #[cfg(test)]
268    fn without_access_boundary(mut self) -> Self {
269        self.is_access_boundary_enabled = false;
270        self
271    }
272
273    // This method is used to build mds credentials from ADC
274    pub(crate) fn from_adc() -> Self {
275        Self {
276            created_by_adc: true,
277            ..Default::default()
278        }
279    }
280
281    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
282        let tp = MDSAccessTokenProvider::builder()
283            .endpoint(self.endpoint)
284            .maybe_scopes(self.scopes)
285            .created_by_adc(self.created_by_adc)
286            .build();
287        self.retry_builder.build(tp)
288    }
289
290    /// Returns a [Credentials] instance with the configured settings.
291    pub fn build(self) -> BuildResult<Credentials> {
292        Ok(self.build_credentials()?.into())
293    }
294
295    /// Returns an [AccessTokenCredentials] instance with the configured settings.
296    ///
297    /// # Example
298    ///
299    /// ```
300    /// # use google_cloud_auth::credentials::mds::Builder;
301    /// # use google_cloud_auth::credentials::{AccessTokenCredentials, AccessTokenCredentialsProvider};
302    /// # async fn sample() -> anyhow::Result<()> {
303    /// let credentials: AccessTokenCredentials = Builder::default()
304    ///     .with_quota_project_id("my-quota-project")
305    ///     .build_access_token_credentials()?;
306    /// let access_token = credentials.access_token().await?;
307    /// println!("Token: {}", access_token.token);
308    /// # Ok(()) }
309    /// ```
310    pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
311        Ok(self.build_credentials()?.into())
312    }
313
314    fn build_credentials(
315        self,
316    ) -> BuildResult<CredentialsWithAccessBoundary<MDSCredentials<TokenCache>>> {
317        let iam_endpoint = self.iam_endpoint_override.clone();
318        let is_access_boundary_enabled = self.is_access_boundary_enabled;
319        let mds_client = MDSClient::new(self.endpoint.clone());
320        let mdsc = MDSCredentials {
321            quota_project_id: self.quota_project_id.clone(),
322            token_provider: TokenCache::new(self.build_token_provider()),
323        };
324        if !is_access_boundary_enabled {
325            return Ok(CredentialsWithAccessBoundary::new_no_op(mdsc));
326        }
327        Ok(CredentialsWithAccessBoundary::new_for_mds(
328            mdsc,
329            mds_client,
330            iam_endpoint,
331        ))
332    }
333
334    /// Returns a [crate::signer::Signer] instance with the configured settings.
335    ///
336    /// The returned [crate::signer::Signer] uses the [IAM signBlob API] to sign content. This API
337    /// requires a network request for each signing operation.
338    ///
339    /// # Example
340    ///
341    /// ```
342    /// # use google_cloud_auth::credentials::mds::Builder;
343    /// # use google_cloud_auth::signer::Signer;
344    /// # async fn sample() -> anyhow::Result<()> {
345    /// let signer: Signer = Builder::default().build_signer()?;
346    /// # Ok(()) }
347    /// ```
348    ///
349    /// [IAM signBlob API]: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/signBlob
350    pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
351        let client = MDSClient::new(self.endpoint.clone());
352        let iam_endpoint = self.iam_endpoint_override.clone();
353        let credentials = self.build()?;
354        let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
355        let signing_provider = iam_endpoint
356            .iter()
357            .fold(signing_provider, |signing_provider, endpoint| {
358                signing_provider.with_iam_endpoint_override(endpoint)
359            });
360        Ok(crate::signer::Signer {
361            inner: Arc::new(signing_provider),
362        })
363    }
364}
365
366#[async_trait::async_trait]
367impl<T> CredentialsProvider for MDSCredentials<T>
368where
369    T: CachedTokenProvider,
370{
371    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
372        let token = self.token_provider.token(extensions).await?;
373
374        AuthHeadersBuilder::new(&token)
375            .maybe_quota_project_id(self.quota_project_id.as_deref())
376            .build()
377    }
378}
379
380#[async_trait::async_trait]
381impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
382where
383    T: CachedTokenProvider,
384{
385    async fn access_token(&self) -> Result<AccessToken> {
386        let token = self.token_provider.token(Extensions::new()).await?;
387        token.into()
388    }
389}
390
391#[derive(Debug, Default)]
392struct MDSAccessTokenProviderBuilder {
393    scopes: Option<Vec<String>>,
394    endpoint: Option<String>,
395    created_by_adc: bool,
396}
397
398impl MDSAccessTokenProviderBuilder {
399    fn build(self) -> MDSAccessTokenProvider {
400        MDSAccessTokenProvider {
401            client: MDSClient::new(self.endpoint),
402            scopes: self.scopes,
403            created_by_adc: self.created_by_adc,
404        }
405    }
406
407    fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
408        self.scopes = v;
409        self
410    }
411
412    fn endpoint<T>(mut self, v: Option<T>) -> Self
413    where
414        T: Into<String>,
415    {
416        self.endpoint = v.map(Into::into);
417        self
418    }
419
420    fn created_by_adc(mut self, v: bool) -> Self {
421        self.created_by_adc = v;
422        self
423    }
424}
425
426#[derive(Debug, Clone)]
427struct MDSAccessTokenProvider {
428    scopes: Option<Vec<String>>,
429    client: MDSClient,
430    created_by_adc: bool,
431}
432
433impl MDSAccessTokenProvider {
434    fn builder() -> MDSAccessTokenProviderBuilder {
435        MDSAccessTokenProviderBuilder::default()
436    }
437
438    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
439    // environment variable is not set, we default to MDS credentials without checking if the code is really
440    // running in an environment with MDS. To help users who got to this state because of lack of credentials
441    // setup on their machines, we provide a detailed error message to them talking about local setup and other
442    // auth mechanisms available to them.
443    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
444    // error message because they deliberately wanted to use an MDS.
445    fn error_message(&self) -> &str {
446        if self.use_adc_message() {
447            MDS_NOT_FOUND_ERROR
448        } else {
449            "failed to fetch token"
450        }
451    }
452
453    fn use_adc_message(&self) -> bool {
454        self.created_by_adc && self.client.is_default_endpoint
455    }
456}
457
458#[async_trait]
459impl TokenProvider for MDSAccessTokenProvider {
460    async fn token(&self) -> Result<Token> {
461        self.client
462            .access_token(self.scopes.clone())
463            .await
464            .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
465    }
466}
467
468#[cfg(test)]
469mod tests {
470    use super::*;
471    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
472    use crate::credentials::QUOTA_PROJECT_KEY;
473    use crate::credentials::tests::{
474        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
475        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
476        get_token_type_from_headers,
477    };
478    use crate::errors;
479    use crate::errors::CredentialsError;
480    use crate::mds::client::MDSTokenResponse;
481    use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT};
482    use crate::token::tests::MockTokenProvider;
483    use base64::{Engine, prelude::BASE64_STANDARD};
484    use http::HeaderValue;
485    use http::header::AUTHORIZATION;
486    use httptest::cycle;
487    use httptest::matchers::{all_of, contains, request, url_decoded};
488    use httptest::responders::{json_encoded, status_code};
489    use httptest::{Expectation, Server};
490    use reqwest::StatusCode;
491    use scoped_env::ScopedEnv;
492    use serde_json::json;
493    use serial_test::{parallel, serial};
494    use std::error::Error;
495    use std::time::Duration;
496    use test_case::test_case;
497    use tokio::time::Instant;
498    use url::Url;
499
500    type TestResult = anyhow::Result<()>;
501
502    #[tokio::test]
503    #[parallel]
504    async fn test_mds_retries_on_transient_failures() -> TestResult {
505        let mut server = Server::run();
506        server.expect(
507            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
508                .times(3)
509                .respond_with(status_code(503)),
510        );
511
512        let provider = Builder::default()
513            .with_endpoint(format!("http://{}", server.addr()))
514            .with_retry_policy(get_mock_auth_retry_policy(3))
515            .with_backoff_policy(get_mock_backoff_policy())
516            .with_retry_throttler(get_mock_retry_throttler())
517            .build_token_provider();
518
519        let err = provider.token().await.unwrap_err();
520        assert!(err.is_transient(), "{err:?}");
521        server.verify_and_clear();
522        Ok(())
523    }
524
525    #[tokio::test]
526    #[parallel]
527    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
528        let mut server = Server::run();
529        server.expect(
530            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
531                .times(1)
532                .respond_with(status_code(401)),
533        );
534
535        let provider = Builder::default()
536            .with_endpoint(format!("http://{}", server.addr()))
537            .with_retry_policy(get_mock_auth_retry_policy(1))
538            .with_backoff_policy(get_mock_backoff_policy())
539            .with_retry_throttler(get_mock_retry_throttler())
540            .build_token_provider();
541
542        let err = provider.token().await.unwrap_err();
543        assert!(!err.is_transient());
544        server.verify_and_clear();
545        Ok(())
546    }
547
548    #[tokio::test]
549    #[parallel]
550    async fn test_mds_retries_for_success() -> TestResult {
551        let mut server = Server::run();
552        let response = MDSTokenResponse {
553            access_token: "test-access-token".to_string(),
554            expires_in: Some(3600),
555            token_type: "test-token-type".to_string(),
556        };
557
558        server.expect(
559            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
560                .times(3)
561                .respond_with(cycle![
562                    status_code(503).body("try-again"),
563                    status_code(503).body("try-again"),
564                    status_code(200)
565                        .append_header("Content-Type", "application/json")
566                        .body(serde_json::to_string(&response).unwrap()),
567                ]),
568        );
569
570        let provider = Builder::default()
571            .with_endpoint(format!("http://{}", server.addr()))
572            .with_retry_policy(get_mock_auth_retry_policy(3))
573            .with_backoff_policy(get_mock_backoff_policy())
574            .with_retry_throttler(get_mock_retry_throttler())
575            .build_token_provider();
576
577        let token = provider.token().await?;
578        assert_eq!(token.token, "test-access-token");
579
580        server.verify_and_clear();
581        Ok(())
582    }
583
584    #[test]
585    #[parallel]
586    fn validate_default_endpoint_urls() {
587        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
588        assert!(
589            default_endpoint_address.is_ok(),
590            "{default_endpoint_address:?}"
591        );
592
593        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
594        assert!(token_endpoint_address.is_ok(), "{token_endpoint_address:?}");
595    }
596
597    #[tokio::test]
598    #[parallel]
599    async fn headers_success() -> TestResult {
600        let token = Token {
601            token: "test-token".to_string(),
602            token_type: "Bearer".to_string(),
603            expires_at: None,
604            metadata: None,
605        };
606
607        let mut mock = MockTokenProvider::new();
608        mock.expect_token().times(1).return_once(|| Ok(token));
609
610        let mdsc = MDSCredentials {
611            quota_project_id: None,
612            token_provider: TokenCache::new(mock),
613        };
614
615        let mut extensions = Extensions::new();
616        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
617        let (headers, entity_tag) = match cached_headers {
618            CacheableResource::New { entity_tag, data } => (data, entity_tag),
619            CacheableResource::NotModified => unreachable!("expecting new headers"),
620        };
621        let token = headers.get(AUTHORIZATION).unwrap();
622        assert_eq!(headers.len(), 1, "{headers:?}");
623        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
624        assert!(token.is_sensitive());
625
626        extensions.insert(entity_tag);
627
628        let cached_headers = mdsc.headers(extensions).await?;
629
630        match cached_headers {
631            CacheableResource::New { .. } => unreachable!("expecting new headers"),
632            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
633        };
634        Ok(())
635    }
636
637    #[tokio::test]
638    #[parallel]
639    async fn access_token_success() -> TestResult {
640        let server = Server::run();
641        let response = MDSTokenResponse {
642            access_token: "test-access-token".to_string(),
643            expires_in: Some(3600),
644            token_type: "Bearer".to_string(),
645        };
646        server.expect(
647            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
648                .respond_with(json_encoded(response)),
649        );
650
651        let creds = Builder::default()
652            .with_endpoint(format!("http://{}", server.addr()))
653            .without_access_boundary()
654            .build_access_token_credentials()
655            .unwrap();
656
657        let access_token = creds.access_token().await.unwrap();
658        assert_eq!(access_token.token, "test-access-token");
659
660        Ok(())
661    }
662
663    #[tokio::test]
664    #[parallel]
665    async fn headers_failure() {
666        let mut mock = MockTokenProvider::new();
667        mock.expect_token()
668            .times(1)
669            .return_once(|| Err(errors::non_retryable_from_str("fail")));
670
671        let mdsc = MDSCredentials {
672            quota_project_id: None,
673            token_provider: TokenCache::new(mock),
674        };
675        let result = mdsc.headers(Extensions::new()).await;
676        assert!(result.is_err(), "{result:?}");
677    }
678
679    #[test]
680    #[parallel]
681    fn error_message_with_adc() {
682        let provider = MDSAccessTokenProvider::builder()
683            .created_by_adc(true)
684            .build();
685
686        let want = MDS_NOT_FOUND_ERROR;
687        let got = provider.error_message();
688        assert!(got.contains(want), "{got}, {provider:?}");
689    }
690
691    #[test_case(false, false)]
692    #[test_case(false, true)]
693    #[test_case(true, true)]
694    fn error_message_without_adc(adc: bool, overridden: bool) {
695        let endpoint = if overridden {
696            Some("http://127.0.0.1")
697        } else {
698            None
699        };
700        let provider = MDSAccessTokenProvider::builder()
701            .endpoint(endpoint)
702            .created_by_adc(adc)
703            .build();
704
705        let not_want = MDS_NOT_FOUND_ERROR;
706        let got = provider.error_message();
707        assert!(!got.contains(not_want), "{got}, {provider:?}");
708    }
709
710    #[tokio::test]
711    #[serial]
712    async fn adc_no_mds() -> TestResult {
713        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
714            // The environment has an MDS, skip the test.
715            return Ok(());
716        };
717
718        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
719        assert!(
720            original_err.to_string().contains("application-default"),
721            "display={err}, debug={err:?}"
722        );
723
724        Ok(())
725    }
726
727    #[tokio::test]
728    #[serial]
729    async fn adc_overridden_mds() -> TestResult {
730        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
731
732        let err = Builder::from_adc()
733            .build_token_provider()
734            .token()
735            .await
736            .unwrap_err();
737
738        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);
739
740        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
741        assert!(original_err.is_transient());
742        assert!(
743            !original_err.to_string().contains("application-default"),
744            "display={err}, debug={err:?}"
745        );
746        let source = find_source_error::<reqwest::Error>(&err);
747        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
748
749        Ok(())
750    }
751
752    #[tokio::test]
753    #[serial]
754    async fn builder_no_mds() -> TestResult {
755        let Err(e) = Builder::default().build_token_provider().token().await else {
756            // The environment has an MDS, skip the test.
757            return Ok(());
758        };
759
760        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
761        assert!(
762            !format!("{:?}", original_err.source()).contains("application-default"),
763            "{e:?}"
764        );
765
766        Ok(())
767    }
768
769    #[tokio::test]
770    #[serial]
771    async fn test_gce_metadata_host_env_var() -> TestResult {
772        let server = Server::run();
773        let scopes = ["scope1", "scope2"];
774        let response = MDSTokenResponse {
775            access_token: "test-access-token".to_string(),
776            expires_in: Some(3600),
777            token_type: "test-token-type".to_string(),
778        };
779        server.expect(
780            Expectation::matching(all_of![
781                request::path(format!("{MDS_DEFAULT_URI}/token")),
782                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
783            ])
784            .respond_with(json_encoded(response)),
785        );
786
787        let addr = server.addr().to_string();
788        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
789        let mdsc = Builder::default()
790            .with_scopes(["scope1", "scope2"])
791            .without_access_boundary()
792            .build()
793            .unwrap();
794        let headers = mdsc.headers(Extensions::new()).await.unwrap();
795        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);
796
797        assert_eq!(
798            get_token_from_headers(headers).unwrap(),
799            "test-access-token"
800        );
801        Ok(())
802    }
803
804    #[tokio::test]
805    #[parallel]
806    async fn headers_success_with_quota_project() -> TestResult {
807        let server = Server::run();
808        let scopes = ["scope1", "scope2"];
809        let response = MDSTokenResponse {
810            access_token: "test-access-token".to_string(),
811            expires_in: Some(3600),
812            token_type: "test-token-type".to_string(),
813        };
814        server.expect(
815            Expectation::matching(all_of![
816                request::path(format!("{MDS_DEFAULT_URI}/token")),
817                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
818            ])
819            .respond_with(json_encoded(response)),
820        );
821
822        let mdsc = Builder::default()
823            .with_scopes(["scope1", "scope2"])
824            .with_endpoint(format!("http://{}", server.addr()))
825            .with_quota_project_id("test-project")
826            .without_access_boundary()
827            .build()?;
828
829        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
830        let token = headers.get(AUTHORIZATION).unwrap();
831        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
832
833        assert_eq!(headers.len(), 2, "{headers:?}");
834        assert_eq!(
835            token,
836            HeaderValue::from_static("test-token-type test-access-token")
837        );
838        assert!(token.is_sensitive());
839        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
840        assert!(!quota_project.is_sensitive());
841
842        Ok(())
843    }
844
845    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
846    #[parallel]
847    async fn token_caching() -> TestResult {
848        let mut server = Server::run();
849        let scopes = vec!["scope1".to_string()];
850        let response = MDSTokenResponse {
851            access_token: "test-access-token".to_string(),
852            expires_in: Some(3600),
853            token_type: "test-token-type".to_string(),
854        };
855        server.expect(
856            Expectation::matching(all_of![
857                request::path(format!("{MDS_DEFAULT_URI}/token")),
858                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
859            ])
860            .times(1)
861            .respond_with(json_encoded(response)),
862        );
863
864        let mdsc = Builder::default()
865            .with_scopes(scopes)
866            .with_endpoint(format!("http://{}", server.addr()))
867            .build()?;
868        let headers = mdsc.headers(Extensions::new()).await?;
869        assert_eq!(
870            get_token_from_headers(headers).unwrap(),
871            "test-access-token"
872        );
873        let headers = mdsc.headers(Extensions::new()).await?;
874        assert_eq!(
875            get_token_from_headers(headers).unwrap(),
876            "test-access-token"
877        );
878
879        // validate that the inner token provider is called only once
880        server.verify_and_clear();
881
882        Ok(())
883    }
884
885    #[tokio::test(start_paused = true)]
886    #[parallel]
887    async fn token_provider_full() -> TestResult {
888        let server = Server::run();
889        let scopes = vec!["scope1".to_string()];
890        let response = MDSTokenResponse {
891            access_token: "test-access-token".to_string(),
892            expires_in: Some(3600),
893            token_type: "test-token-type".to_string(),
894        };
895        server.expect(
896            Expectation::matching(all_of![
897                request::path(format!("{MDS_DEFAULT_URI}/token")),
898                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
899            ])
900            .respond_with(json_encoded(response)),
901        );
902
903        let token = Builder::default()
904            .with_endpoint(format!("http://{}", server.addr()))
905            .with_scopes(scopes)
906            .build_token_provider()
907            .token()
908            .await?;
909
910        let now = tokio::time::Instant::now();
911        assert_eq!(token.token, "test-access-token");
912        assert_eq!(token.token_type, "test-token-type");
913        assert!(
914            token
915                .expires_at
916                .is_some_and(|d| d >= now + Duration::from_secs(3600))
917        );
918
919        Ok(())
920    }
921
922    #[tokio::test(start_paused = true)]
923    #[parallel]
924    async fn token_provider_full_no_scopes() -> TestResult {
925        let server = Server::run();
926        let response = MDSTokenResponse {
927            access_token: "test-access-token".to_string(),
928            expires_in: Some(3600),
929            token_type: "test-token-type".to_string(),
930        };
931        server.expect(
932            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
933                .respond_with(json_encoded(response)),
934        );
935
936        let token = Builder::default()
937            .with_endpoint(format!("http://{}", server.addr()))
938            .build_token_provider()
939            .token()
940            .await?;
941
942        let now = Instant::now();
943        assert_eq!(token.token, "test-access-token");
944        assert_eq!(token.token_type, "test-token-type");
945        assert!(
946            token
947                .expires_at
948                .is_some_and(|d| d == now + Duration::from_secs(3600))
949        );
950
951        Ok(())
952    }
953
954    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
955    #[parallel]
956    async fn credential_provider_full() -> TestResult {
957        let server = Server::run();
958        let scopes = vec!["scope1".to_string()];
959        let response = MDSTokenResponse {
960            access_token: "test-access-token".to_string(),
961            expires_in: None,
962            token_type: "test-token-type".to_string(),
963        };
964        server.expect(
965            Expectation::matching(all_of![
966                request::path(format!("{MDS_DEFAULT_URI}/token")),
967                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
968            ])
969            .respond_with(json_encoded(response)),
970        );
971
972        let mdsc = Builder::default()
973            .with_endpoint(format!("http://{}", server.addr()))
974            .with_scopes(scopes)
975            .without_access_boundary()
976            .build()?;
977        let headers = mdsc.headers(Extensions::new()).await?;
978        assert_eq!(
979            get_token_from_headers(headers.clone()).unwrap(),
980            "test-access-token"
981        );
982        assert_eq!(
983            get_token_type_from_headers(headers).unwrap(),
984            "test-token-type"
985        );
986
987        Ok(())
988    }
989
990    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
991    #[parallel]
992    async fn credentials_headers_retryable_error() -> TestResult {
993        let server = Server::run();
994        let scopes = vec!["scope1".to_string()];
995        server.expect(
996            Expectation::matching(all_of![
997                request::path(format!("{MDS_DEFAULT_URI}/token")),
998                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
999            ])
1000            .respond_with(status_code(503)),
1001        );
1002
1003        let mdsc = Builder::default()
1004            .with_endpoint(format!("http://{}", server.addr()))
1005            .with_scopes(scopes)
1006            .without_access_boundary()
1007            .build()?;
1008        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1009        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1010        assert!(original_err.is_transient());
1011        let source = find_source_error::<reqwest::Error>(&err);
1012        assert!(
1013            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1014            "{err:?}"
1015        );
1016
1017        Ok(())
1018    }
1019
1020    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1021    #[parallel]
1022    async fn credentials_headers_nonretryable_error() -> TestResult {
1023        let server = Server::run();
1024        let scopes = vec!["scope1".to_string()];
1025        server.expect(
1026            Expectation::matching(all_of![
1027                request::path(format!("{MDS_DEFAULT_URI}/token")),
1028                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1029            ])
1030            .respond_with(status_code(401)),
1031        );
1032
1033        let mdsc = Builder::default()
1034            .with_endpoint(format!("http://{}", server.addr()))
1035            .with_scopes(scopes)
1036            .without_access_boundary()
1037            .build()?;
1038
1039        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1040        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1041        assert!(!original_err.is_transient());
1042        let source = find_source_error::<reqwest::Error>(&err);
1043        assert!(
1044            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1045            "{err:?}"
1046        );
1047
1048        Ok(())
1049    }
1050
1051    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1052    #[parallel]
1053    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
1054        let server = Server::run();
1055        let scopes = vec!["scope1".to_string()];
1056        server.expect(
1057            Expectation::matching(all_of![
1058                request::path(format!("{MDS_DEFAULT_URI}/token")),
1059                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1060            ])
1061            .respond_with(json_encoded("bad json")),
1062        );
1063
1064        let mdsc = Builder::default()
1065            .with_endpoint(format!("http://{}", server.addr()))
1066            .with_scopes(scopes)
1067            .without_access_boundary()
1068            .build()?;
1069
1070        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1071        assert!(!e.is_transient());
1072
1073        Ok(())
1074    }
1075
1076    #[tokio::test]
1077    #[parallel]
1078    async fn get_default_universe_domain_success() -> TestResult {
1079        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
1080        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
1081        Ok(())
1082    }
1083
1084    #[tokio::test]
1085    #[parallel]
1086    async fn get_mds_signer() -> TestResult {
1087        let server = Server::run();
1088        server.expect(
1089            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
1090                .respond_with(json_encoded(MDSTokenResponse {
1091                    access_token: "test-access-token".to_string(),
1092                    expires_in: None,
1093                    token_type: "Bearer".to_string(),
1094                })),
1095        );
1096        server.expect(
1097            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
1098                .respond_with(status_code(200).body("test-client-email")),
1099        );
1100        server.expect(
1101            Expectation::matching(all_of![
1102                request::method_path(
1103                    "POST",
1104                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
1105                ),
1106                request::headers(contains(("authorization", "Bearer test-access-token"))),
1107            ])
1108            .respond_with(json_encoded(json!({
1109                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
1110            }))),
1111        );
1112
1113        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
1114
1115        let signer = Builder::default()
1116            .with_endpoint(&endpoint)
1117            .maybe_iam_endpoint_override(Some(endpoint))
1118            .without_access_boundary()
1119            .build_signer()?;
1120
1121        let client_email = signer.client_email().await?;
1122        assert_eq!(client_email, "test-client-email");
1123
1124        let signature = signer.sign(b"test").await?;
1125        assert_eq!(signature.as_ref(), b"signed_blob");
1126
1127        Ok(())
1128    }
1129
1130    #[tokio::test]
1131    #[parallel]
1132    #[cfg(google_cloud_unstable_trusted_boundaries)]
1133    async fn e2e_access_boundary() -> TestResult {
1134        use crate::credentials::tests::get_access_boundary_from_headers;
1135
1136        let server = Server::run();
1137        server.expect(
1138            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
1139                .respond_with(json_encoded(MDSTokenResponse {
1140                    access_token: "test-access-token".to_string(),
1141                    expires_in: None,
1142                    token_type: "Bearer".to_string(),
1143                })),
1144        );
1145        server.expect(
1146            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
1147                .respond_with(status_code(200).body("test-client-email")),
1148        );
1149        server.expect(
1150            Expectation::matching(all_of![
1151                request::method_path(
1152                    "GET",
1153                    "/v1/projects/-/serviceAccounts/test-client-email/allowedLocations"
1154                ),
1155                request::headers(contains(("authorization", "Bearer test-access-token"))),
1156            ])
1157            .respond_with(json_encoded(json!({
1158                "locations": ["us-central1", "us-east1"],
1159                "encodedLocations": "0x1234"
1160            }))),
1161        );
1162
1163        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
1164
1165        let creds = Builder::default()
1166            .with_endpoint(&endpoint)
1167            .maybe_iam_endpoint_override(Some(endpoint))
1168            .build_credentials()?;
1169
1170        // let the access boundary background thread update
1171        creds.wait_for_boundary().await;
1172
1173        let headers = creds.headers(Extensions::new()).await?;
1174        let token = get_token_from_headers(headers.clone());
1175        let access_boundary = get_access_boundary_from_headers(headers);
1176        assert!(token.is_some(), "should have some token: {token:?}");
1177        assert_eq!(
1178            access_boundary.as_deref(),
1179            Some("0x1234"),
1180            "should be 0x1234 but found: {access_boundary:?}"
1181        );
1182
1183        Ok(())
1184    }
1185}