Skip to main content

google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
28//! You can use this access token to securely authenticate with Google Cloud,
29//! without having to download secrets or other credentials. The types in this
30//! module allow you to retrieve these access tokens, and can be used with
31//! the Google Cloud client libraries for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok::<(), anyhow::Error>(())
46//! # });
47//! ```
48//!
49//! ## Example: Creating credentials with custom retry behavior
50//!
51//! ```
52//! # use google_cloud_auth::credentials::mds::Builder;
53//! # use google_cloud_auth::credentials::Credentials;
54//! # use http::Extensions;
55//! # use std::time::Duration;
56//! # tokio_test::block_on(async {
57//! use google_cloud_gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
58//! use google_cloud_gax::exponential_backoff::ExponentialBackoff;
59//! let backoff = ExponentialBackoff::default();
60//! let credentials: Credentials = Builder::default()
61//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
62//!     .with_backoff_policy(backoff)
63//!     .build()?;
64//! let headers = credentials.headers(Extensions::new()).await?;
65//! println!("Headers: {headers:?}");
66//! # Ok::<(), anyhow::Error>(())
67//! # });
68//! ```
69//!
70//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
71//! [Cloud Run]: https://cloud.google.com/run
72//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
73//! [gce-link]: https://cloud.google.com/products/compute
74//! [gke-link]: https://cloud.google.com/kubernetes-engine
75//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
76
77use crate::access_boundary::CredentialsWithAccessBoundary;
78use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
79use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
80use crate::headers_util::AuthHeadersBuilder;
81use crate::mds::client::Client as MDSClient;
82use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
83use crate::token::{CachedTokenProvider, Token, TokenProvider};
84use crate::token_cache::TokenCache;
85use crate::{BuildResult, Result};
86use async_trait::async_trait;
87use google_cloud_gax::backoff_policy::BackoffPolicyArg;
88use google_cloud_gax::error::CredentialsError;
89use google_cloud_gax::retry_policy::RetryPolicyArg;
90use google_cloud_gax::retry_throttler::RetryThrottlerArg;
91use http::{Extensions, HeaderMap};
92use std::default::Default;
93use std::sync::Arc;
94
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
/// Detailed error text returned when credentials chosen by Application Default
/// Credentials cannot reach the default metadata service. It points the user at
/// local credential setup instead of a bare "failed to fetch token".
///
/// Each `\` line continuation drops the newline and the next line's leading
/// whitespace, so the single space before every `\` is what separates words.
const MDS_NOT_FOUND_ERROR: &str = "Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud \
    Environment and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at \
    https://cloud.google.com/docs/authentication/client-libraries";
103
104#[derive(Debug)]
105struct MDSCredentials<T>
106where
107    T: CachedTokenProvider,
108{
109    quota_project_id: Option<String>,
110    token_provider: T,
111}
112
113/// Creates [Credentials] instances backed by the [Metadata Service].
114///
115/// While the Google Cloud client libraries for Rust default to credentials
116/// backed by the metadata service, some applications may need to:
117/// * Customize the metadata service credentials in some way
118/// * Bypass the [Application Default Credentials] lookup and only
119///   use the metadata server credentials
120/// * Use the credentials directly outside the client libraries
121///
122/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
123/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
124#[derive(Debug)]
125pub struct Builder {
126    endpoint: Option<String>,
127    quota_project_id: Option<String>,
128    scopes: Option<Vec<String>>,
129    created_by_adc: bool,
130    retry_builder: RetryTokenProviderBuilder,
131    iam_endpoint_override: Option<String>,
132    is_access_boundary_enabled: bool,
133}
134
135impl Default for Builder {
136    fn default() -> Self {
137        Self {
138            endpoint: None,
139            quota_project_id: None,
140            scopes: None,
141            created_by_adc: false,
142            retry_builder: RetryTokenProviderBuilder::default(),
143            iam_endpoint_override: None,
144            is_access_boundary_enabled: true,
145        }
146    }
147}
148
149impl Builder {
150    /// Sets the endpoint for this credentials.
151    ///
152    /// A trailing slash is significant, so specify the base URL without a trailing
153    /// slash. If not set, the credentials use `http://metadata.google.internal`.
154    ///
155    /// # Example
156    /// ```
157    /// # use google_cloud_auth::credentials::mds::Builder;
158    /// # tokio_test::block_on(async {
159    /// let credentials = Builder::default()
160    ///     .with_endpoint("https://metadata.google.foobar")
161    ///     .build();
162    /// # });
163    /// ```
164    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
165        self.endpoint = Some(endpoint.into());
166        self
167    }
168
169    /// Set the [quota project] for this credentials.
170    ///
171    /// In some services, you can use a service account in
172    /// one project for authentication and authorization, and charge
173    /// the usage to a different project. This may require that the
174    /// service account has `serviceusage.services.use` permissions on the quota project.
175    ///
176    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
177    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
178        self.quota_project_id = Some(quota_project_id.into());
179        self
180    }
181
182    /// Sets the [scopes] for this credentials.
183    ///
184    /// Metadata server issues tokens based on the requested scopes.
185    /// If no scopes are specified, the credentials defaults to all
186    /// scopes configured for the [default service account] on the instance.
187    ///
188    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
189    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
190    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
191    where
192        I: IntoIterator<Item = S>,
193        S: Into<String>,
194    {
195        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
196        self
197    }
198
199    /// Configure the retry policy for fetching tokens.
200    ///
201    /// The retry policy controls how to handle retries, and sets limits on
202    /// the number of attempts or the total time spent retrying.
203    ///
204    /// ```
205    /// # use google_cloud_auth::credentials::mds::Builder;
206    /// # tokio_test::block_on(async {
207    /// use google_cloud_gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
208    /// let credentials = Builder::default()
209    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
210    ///     .build();
211    /// # });
212    /// ```
213    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
214        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
215        self
216    }
217
218    /// Configure the retry backoff policy.
219    ///
220    /// The backoff policy controls how long to wait in between retry attempts.
221    ///
222    /// ```
223    /// # use google_cloud_auth::credentials::mds::Builder;
224    /// # use std::time::Duration;
225    /// # tokio_test::block_on(async {
226    /// use google_cloud_gax::exponential_backoff::ExponentialBackoff;
227    /// let policy = ExponentialBackoff::default();
228    /// let credentials = Builder::default()
229    ///     .with_backoff_policy(policy)
230    ///     .build();
231    /// # });
232    /// ```
233    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
234        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
235        self
236    }
237
238    /// Configure the retry throttler.
239    ///
240    /// Advanced applications may want to configure a retry throttler to
241    /// [Address Cascading Failures] and when [Handling Overload] conditions.
242    /// The authentication library throttles its retry loop, using a policy to
243    /// control the throttling algorithm. Use this method to fine tune or
244    /// customize the default retry throttler.
245    ///
246    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
247    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
248    ///
249    /// ```
250    /// # use google_cloud_auth::credentials::mds::Builder;
251    /// # tokio_test::block_on(async {
252    /// use google_cloud_gax::retry_throttler::AdaptiveThrottler;
253    /// let credentials = Builder::default()
254    ///     .with_retry_throttler(AdaptiveThrottler::default())
255    ///     .build();
256    /// # });
257    /// ```
258    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
259        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
260        self
261    }
262
263    #[cfg(test)]
264    fn maybe_iam_endpoint_override(mut self, iam_endpoint_override: Option<String>) -> Self {
265        self.iam_endpoint_override = iam_endpoint_override;
266        self
267    }
268
269    #[cfg(test)]
270    fn without_access_boundary(mut self) -> Self {
271        self.is_access_boundary_enabled = false;
272        self
273    }
274
275    // This method is used to build mds credentials from ADC
276    pub(crate) fn from_adc() -> Self {
277        Self {
278            created_by_adc: true,
279            ..Default::default()
280        }
281    }
282
283    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
284        let tp = MDSAccessTokenProvider::builder()
285            .endpoint(self.endpoint)
286            .maybe_scopes(self.scopes)
287            .created_by_adc(self.created_by_adc)
288            .build();
289        self.retry_builder.build(tp)
290    }
291
292    /// Returns a [Credentials] instance with the configured settings.
293    pub fn build(self) -> BuildResult<Credentials> {
294        Ok(self.build_credentials()?.into())
295    }
296
297    /// Returns an [AccessTokenCredentials] instance with the configured settings.
298    ///
299    /// # Example
300    ///
301    /// ```
302    /// # use google_cloud_auth::credentials::mds::Builder;
303    /// # use google_cloud_auth::credentials::{AccessTokenCredentials, AccessTokenCredentialsProvider};
304    /// # tokio_test::block_on(async {
305    /// let credentials: AccessTokenCredentials = Builder::default()
306    ///     .with_quota_project_id("my-quota-project")
307    ///     .build_access_token_credentials()?;
308    /// let access_token = credentials.access_token().await?;
309    /// println!("Token: {}", access_token.token);
310    /// # Ok::<(), anyhow::Error>(())
311    /// # });
312    /// ```
313    pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
314        Ok(self.build_credentials()?.into())
315    }
316
317    fn build_credentials(
318        self,
319    ) -> BuildResult<CredentialsWithAccessBoundary<MDSCredentials<TokenCache>>> {
320        let iam_endpoint = self.iam_endpoint_override.clone();
321        let is_access_boundary_enabled = self.is_access_boundary_enabled;
322        let mds_client = MDSClient::new(self.endpoint.clone());
323        let mdsc = MDSCredentials {
324            quota_project_id: self.quota_project_id.clone(),
325            token_provider: TokenCache::new(self.build_token_provider()),
326        };
327        if !is_access_boundary_enabled {
328            return Ok(CredentialsWithAccessBoundary::new_no_op(mdsc));
329        }
330        Ok(CredentialsWithAccessBoundary::new_for_mds(
331            mdsc,
332            mds_client,
333            iam_endpoint,
334        ))
335    }
336
337    /// Returns a [crate::signer::Signer] instance with the configured settings.
338    ///
339    /// The returned [crate::signer::Signer] uses the [IAM signBlob API] to sign content. This API
340    /// requires a network request for each signing operation.
341    ///
342    /// # Example
343    ///
344    /// ```
345    /// # use google_cloud_auth::credentials::mds::Builder;
346    /// # use google_cloud_auth::signer::Signer;
347    /// # tokio_test::block_on(async {
348    /// let signer: Signer = Builder::default().build_signer()?;
349    /// # Ok::<(), anyhow::Error>(())
350    /// # });
351    /// ```
352    ///
353    /// [IAM signBlob API]: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/signBlob
354    pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
355        let client = MDSClient::new(self.endpoint.clone());
356        let iam_endpoint = self.iam_endpoint_override.clone();
357        let credentials = self.build()?;
358        let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
359        let signing_provider = iam_endpoint
360            .iter()
361            .fold(signing_provider, |signing_provider, endpoint| {
362                signing_provider.with_iam_endpoint_override(endpoint)
363            });
364        Ok(crate::signer::Signer {
365            inner: Arc::new(signing_provider),
366        })
367    }
368}
369
370#[async_trait::async_trait]
371impl<T> CredentialsProvider for MDSCredentials<T>
372where
373    T: CachedTokenProvider,
374{
375    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
376        let token = self.token_provider.token(extensions).await?;
377
378        AuthHeadersBuilder::new(&token)
379            .maybe_quota_project_id(self.quota_project_id.as_deref())
380            .build()
381    }
382}
383
384#[async_trait::async_trait]
385impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
386where
387    T: CachedTokenProvider,
388{
389    async fn access_token(&self) -> Result<AccessToken> {
390        let token = self.token_provider.token(Extensions::new()).await?;
391        token.into()
392    }
393}
394
395#[derive(Debug, Default)]
396struct MDSAccessTokenProviderBuilder {
397    scopes: Option<Vec<String>>,
398    endpoint: Option<String>,
399    created_by_adc: bool,
400}
401
402impl MDSAccessTokenProviderBuilder {
403    fn build(self) -> MDSAccessTokenProvider {
404        MDSAccessTokenProvider {
405            client: MDSClient::new(self.endpoint),
406            scopes: self.scopes,
407            created_by_adc: self.created_by_adc,
408        }
409    }
410
411    fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
412        self.scopes = v;
413        self
414    }
415
416    fn endpoint<T>(mut self, v: Option<T>) -> Self
417    where
418        T: Into<String>,
419    {
420        self.endpoint = v.map(Into::into);
421        self
422    }
423
424    fn created_by_adc(mut self, v: bool) -> Self {
425        self.created_by_adc = v;
426        self
427    }
428}
429
430#[derive(Debug, Clone)]
431struct MDSAccessTokenProvider {
432    scopes: Option<Vec<String>>,
433    client: MDSClient,
434    created_by_adc: bool,
435}
436
437impl MDSAccessTokenProvider {
438    fn builder() -> MDSAccessTokenProviderBuilder {
439        MDSAccessTokenProviderBuilder::default()
440    }
441
442    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
443    // environment variable is not set, we default to MDS credentials without checking if the code is really
444    // running in an environment with MDS. To help users who got to this state because of lack of credentials
445    // setup on their machines, we provide a detailed error message to them talking about local setup and other
446    // auth mechanisms available to them.
447    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
448    // error message because they deliberately wanted to use an MDS.
449    fn error_message(&self) -> &str {
450        if self.use_adc_message() {
451            MDS_NOT_FOUND_ERROR
452        } else {
453            "failed to fetch token"
454        }
455    }
456
457    fn use_adc_message(&self) -> bool {
458        self.created_by_adc && self.client.is_default_endpoint
459    }
460}
461
462#[async_trait]
463impl TokenProvider for MDSAccessTokenProvider {
464    async fn token(&self) -> Result<Token> {
465        self.client
466            .access_token(self.scopes.clone())
467            .await
468            .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
469    }
470}
471
472#[cfg(test)]
473mod tests {
474    use super::*;
475    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
476    use crate::credentials::QUOTA_PROJECT_KEY;
477    use crate::credentials::tests::{
478        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
479        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
480        get_token_type_from_headers,
481    };
482    use crate::errors;
483    use crate::errors::CredentialsError;
484    use crate::mds::client::MDSTokenResponse;
485    use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT};
486    use crate::token::tests::MockTokenProvider;
487    use base64::{Engine, prelude::BASE64_STANDARD};
488    use http::HeaderValue;
489    use http::header::AUTHORIZATION;
490    use httptest::cycle;
491    use httptest::matchers::{all_of, contains, request, url_decoded};
492    use httptest::responders::{json_encoded, status_code};
493    use httptest::{Expectation, Server};
494    use reqwest::StatusCode;
495    use scoped_env::ScopedEnv;
496    use serde_json::json;
497    use serial_test::{parallel, serial};
498    use std::error::Error;
499    use std::time::Duration;
500    use test_case::test_case;
501    use tokio::time::Instant;
502    use url::Url;
503
504    type TestResult = anyhow::Result<()>;
505
506    #[tokio::test]
507    #[parallel]
508    async fn test_mds_retries_on_transient_failures() -> TestResult {
509        let mut server = Server::run();
510        server.expect(
511            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
512                .times(3)
513                .respond_with(status_code(503)),
514        );
515
516        let provider = Builder::default()
517            .with_endpoint(format!("http://{}", server.addr()))
518            .with_retry_policy(get_mock_auth_retry_policy(3))
519            .with_backoff_policy(get_mock_backoff_policy())
520            .with_retry_throttler(get_mock_retry_throttler())
521            .build_token_provider();
522
523        let err = provider.token().await.unwrap_err();
524        assert!(err.is_transient(), "{err:?}");
525        server.verify_and_clear();
526        Ok(())
527    }
528
529    #[tokio::test]
530    #[parallel]
531    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
532        let mut server = Server::run();
533        server.expect(
534            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
535                .times(1)
536                .respond_with(status_code(401)),
537        );
538
539        let provider = Builder::default()
540            .with_endpoint(format!("http://{}", server.addr()))
541            .with_retry_policy(get_mock_auth_retry_policy(1))
542            .with_backoff_policy(get_mock_backoff_policy())
543            .with_retry_throttler(get_mock_retry_throttler())
544            .build_token_provider();
545
546        let err = provider.token().await.unwrap_err();
547        assert!(!err.is_transient());
548        server.verify_and_clear();
549        Ok(())
550    }
551
552    #[tokio::test]
553    #[parallel]
554    async fn test_mds_retries_for_success() -> TestResult {
555        let mut server = Server::run();
556        let response = MDSTokenResponse {
557            access_token: "test-access-token".to_string(),
558            expires_in: Some(3600),
559            token_type: "test-token-type".to_string(),
560        };
561
562        server.expect(
563            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
564                .times(3)
565                .respond_with(cycle![
566                    status_code(503).body("try-again"),
567                    status_code(503).body("try-again"),
568                    status_code(200)
569                        .append_header("Content-Type", "application/json")
570                        .body(serde_json::to_string(&response).unwrap()),
571                ]),
572        );
573
574        let provider = Builder::default()
575            .with_endpoint(format!("http://{}", server.addr()))
576            .with_retry_policy(get_mock_auth_retry_policy(3))
577            .with_backoff_policy(get_mock_backoff_policy())
578            .with_retry_throttler(get_mock_retry_throttler())
579            .build_token_provider();
580
581        let token = provider.token().await?;
582        assert_eq!(token.token, "test-access-token");
583
584        server.verify_and_clear();
585        Ok(())
586    }
587
588    #[test]
589    #[parallel]
590    fn validate_default_endpoint_urls() {
591        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
592        assert!(
593            default_endpoint_address.is_ok(),
594            "{default_endpoint_address:?}"
595        );
596
597        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
598        assert!(token_endpoint_address.is_ok(), "{token_endpoint_address:?}");
599    }
600
601    #[tokio::test]
602    #[parallel]
603    async fn headers_success() -> TestResult {
604        let token = Token {
605            token: "test-token".to_string(),
606            token_type: "Bearer".to_string(),
607            expires_at: None,
608            metadata: None,
609        };
610
611        let mut mock = MockTokenProvider::new();
612        mock.expect_token().times(1).return_once(|| Ok(token));
613
614        let mdsc = MDSCredentials {
615            quota_project_id: None,
616            token_provider: TokenCache::new(mock),
617        };
618
619        let mut extensions = Extensions::new();
620        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
621        let (headers, entity_tag) = match cached_headers {
622            CacheableResource::New { entity_tag, data } => (data, entity_tag),
623            CacheableResource::NotModified => unreachable!("expecting new headers"),
624        };
625        let token = headers.get(AUTHORIZATION).unwrap();
626        assert_eq!(headers.len(), 1, "{headers:?}");
627        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
628        assert!(token.is_sensitive());
629
630        extensions.insert(entity_tag);
631
632        let cached_headers = mdsc.headers(extensions).await?;
633
634        match cached_headers {
635            CacheableResource::New { .. } => unreachable!("expecting new headers"),
636            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
637        };
638        Ok(())
639    }
640
641    #[tokio::test]
642    #[parallel]
643    async fn access_token_success() -> TestResult {
644        let server = Server::run();
645        let response = MDSTokenResponse {
646            access_token: "test-access-token".to_string(),
647            expires_in: Some(3600),
648            token_type: "Bearer".to_string(),
649        };
650        server.expect(
651            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
652                .respond_with(json_encoded(response)),
653        );
654
655        let creds = Builder::default()
656            .with_endpoint(format!("http://{}", server.addr()))
657            .without_access_boundary()
658            .build_access_token_credentials()
659            .unwrap();
660
661        let access_token = creds.access_token().await.unwrap();
662        assert_eq!(access_token.token, "test-access-token");
663
664        Ok(())
665    }
666
667    #[tokio::test]
668    #[parallel]
669    async fn headers_failure() {
670        let mut mock = MockTokenProvider::new();
671        mock.expect_token()
672            .times(1)
673            .return_once(|| Err(errors::non_retryable_from_str("fail")));
674
675        let mdsc = MDSCredentials {
676            quota_project_id: None,
677            token_provider: TokenCache::new(mock),
678        };
679        let result = mdsc.headers(Extensions::new()).await;
680        assert!(result.is_err(), "{result:?}");
681    }
682
683    #[test]
684    #[parallel]
685    fn error_message_with_adc() {
686        let provider = MDSAccessTokenProvider::builder()
687            .created_by_adc(true)
688            .build();
689
690        let want = MDS_NOT_FOUND_ERROR;
691        let got = provider.error_message();
692        assert!(got.contains(want), "{got}, {provider:?}");
693    }
694
695    #[test_case(false, false)]
696    #[test_case(false, true)]
697    #[test_case(true, true)]
698    fn error_message_without_adc(adc: bool, overridden: bool) {
699        let endpoint = if overridden {
700            Some("http://127.0.0.1")
701        } else {
702            None
703        };
704        let provider = MDSAccessTokenProvider::builder()
705            .endpoint(endpoint)
706            .created_by_adc(adc)
707            .build();
708
709        let not_want = MDS_NOT_FOUND_ERROR;
710        let got = provider.error_message();
711        assert!(!got.contains(not_want), "{got}, {provider:?}");
712    }
713
714    #[tokio::test]
715    #[serial]
716    async fn adc_no_mds() -> TestResult {
717        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
718            // The environment has an MDS, skip the test.
719            return Ok(());
720        };
721
722        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
723        assert!(
724            original_err.to_string().contains("application-default"),
725            "display={err}, debug={err:?}"
726        );
727
728        Ok(())
729    }
730
731    #[tokio::test]
732    #[serial]
733    async fn adc_overridden_mds() -> TestResult {
734        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
735
736        let err = Builder::from_adc()
737            .build_token_provider()
738            .token()
739            .await
740            .unwrap_err();
741
742        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);
743
744        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
745        assert!(original_err.is_transient());
746        assert!(
747            !original_err.to_string().contains("application-default"),
748            "display={err}, debug={err:?}"
749        );
750        let source = find_source_error::<reqwest::Error>(&err);
751        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
752
753        Ok(())
754    }
755
756    #[tokio::test]
757    #[serial]
758    async fn builder_no_mds() -> TestResult {
759        let Err(e) = Builder::default().build_token_provider().token().await else {
760            // The environment has an MDS, skip the test.
761            return Ok(());
762        };
763
764        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
765        assert!(
766            !format!("{:?}", original_err.source()).contains("application-default"),
767            "{e:?}"
768        );
769
770        Ok(())
771    }
772
773    #[tokio::test]
774    #[serial]
775    async fn test_gce_metadata_host_env_var() -> TestResult {
776        let server = Server::run();
777        let scopes = ["scope1", "scope2"];
778        let response = MDSTokenResponse {
779            access_token: "test-access-token".to_string(),
780            expires_in: Some(3600),
781            token_type: "test-token-type".to_string(),
782        };
783        server.expect(
784            Expectation::matching(all_of![
785                request::path(format!("{MDS_DEFAULT_URI}/token")),
786                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
787            ])
788            .respond_with(json_encoded(response)),
789        );
790
791        let addr = server.addr().to_string();
792        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
793        let mdsc = Builder::default()
794            .with_scopes(["scope1", "scope2"])
795            .without_access_boundary()
796            .build()
797            .unwrap();
798        let headers = mdsc.headers(Extensions::new()).await.unwrap();
799        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);
800
801        assert_eq!(
802            get_token_from_headers(headers).unwrap(),
803            "test-access-token"
804        );
805        Ok(())
806    }
807
808    #[tokio::test]
809    #[parallel]
810    async fn headers_success_with_quota_project() -> TestResult {
811        let server = Server::run();
812        let scopes = ["scope1", "scope2"];
813        let response = MDSTokenResponse {
814            access_token: "test-access-token".to_string(),
815            expires_in: Some(3600),
816            token_type: "test-token-type".to_string(),
817        };
818        server.expect(
819            Expectation::matching(all_of![
820                request::path(format!("{MDS_DEFAULT_URI}/token")),
821                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
822            ])
823            .respond_with(json_encoded(response)),
824        );
825
826        let mdsc = Builder::default()
827            .with_scopes(["scope1", "scope2"])
828            .with_endpoint(format!("http://{}", server.addr()))
829            .with_quota_project_id("test-project")
830            .without_access_boundary()
831            .build()?;
832
833        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
834        let token = headers.get(AUTHORIZATION).unwrap();
835        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
836
837        assert_eq!(headers.len(), 2, "{headers:?}");
838        assert_eq!(
839            token,
840            HeaderValue::from_static("test-token-type test-access-token")
841        );
842        assert!(token.is_sensitive());
843        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
844        assert!(!quota_project.is_sensitive());
845
846        Ok(())
847    }
848
849    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
850    #[parallel]
851    async fn token_caching() -> TestResult {
852        let mut server = Server::run();
853        let scopes = vec!["scope1".to_string()];
854        let response = MDSTokenResponse {
855            access_token: "test-access-token".to_string(),
856            expires_in: Some(3600),
857            token_type: "test-token-type".to_string(),
858        };
859        server.expect(
860            Expectation::matching(all_of![
861                request::path(format!("{MDS_DEFAULT_URI}/token")),
862                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
863            ])
864            .times(1)
865            .respond_with(json_encoded(response)),
866        );
867
868        let mdsc = Builder::default()
869            .with_scopes(scopes)
870            .with_endpoint(format!("http://{}", server.addr()))
871            .build()?;
872        let headers = mdsc.headers(Extensions::new()).await?;
873        assert_eq!(
874            get_token_from_headers(headers).unwrap(),
875            "test-access-token"
876        );
877        let headers = mdsc.headers(Extensions::new()).await?;
878        assert_eq!(
879            get_token_from_headers(headers).unwrap(),
880            "test-access-token"
881        );
882
883        // validate that the inner token provider is called only once
884        server.verify_and_clear();
885
886        Ok(())
887    }
888
889    #[tokio::test(start_paused = true)]
890    #[parallel]
891    async fn token_provider_full() -> TestResult {
892        let server = Server::run();
893        let scopes = vec!["scope1".to_string()];
894        let response = MDSTokenResponse {
895            access_token: "test-access-token".to_string(),
896            expires_in: Some(3600),
897            token_type: "test-token-type".to_string(),
898        };
899        server.expect(
900            Expectation::matching(all_of![
901                request::path(format!("{MDS_DEFAULT_URI}/token")),
902                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
903            ])
904            .respond_with(json_encoded(response)),
905        );
906
907        let token = Builder::default()
908            .with_endpoint(format!("http://{}", server.addr()))
909            .with_scopes(scopes)
910            .build_token_provider()
911            .token()
912            .await?;
913
914        let now = tokio::time::Instant::now();
915        assert_eq!(token.token, "test-access-token");
916        assert_eq!(token.token_type, "test-token-type");
917        assert!(
918            token
919                .expires_at
920                .is_some_and(|d| d >= now + Duration::from_secs(3600))
921        );
922
923        Ok(())
924    }
925
926    #[tokio::test(start_paused = true)]
927    #[parallel]
928    async fn token_provider_full_no_scopes() -> TestResult {
929        let server = Server::run();
930        let response = MDSTokenResponse {
931            access_token: "test-access-token".to_string(),
932            expires_in: Some(3600),
933            token_type: "test-token-type".to_string(),
934        };
935        server.expect(
936            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
937                .respond_with(json_encoded(response)),
938        );
939
940        let token = Builder::default()
941            .with_endpoint(format!("http://{}", server.addr()))
942            .build_token_provider()
943            .token()
944            .await?;
945
946        let now = Instant::now();
947        assert_eq!(token.token, "test-access-token");
948        assert_eq!(token.token_type, "test-token-type");
949        assert!(
950            token
951                .expires_at
952                .is_some_and(|d| d == now + Duration::from_secs(3600))
953        );
954
955        Ok(())
956    }
957
958    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
959    #[parallel]
960    async fn credential_provider_full() -> TestResult {
961        let server = Server::run();
962        let scopes = vec!["scope1".to_string()];
963        let response = MDSTokenResponse {
964            access_token: "test-access-token".to_string(),
965            expires_in: None,
966            token_type: "test-token-type".to_string(),
967        };
968        server.expect(
969            Expectation::matching(all_of![
970                request::path(format!("{MDS_DEFAULT_URI}/token")),
971                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
972            ])
973            .respond_with(json_encoded(response)),
974        );
975
976        let mdsc = Builder::default()
977            .with_endpoint(format!("http://{}", server.addr()))
978            .with_scopes(scopes)
979            .without_access_boundary()
980            .build()?;
981        let headers = mdsc.headers(Extensions::new()).await?;
982        assert_eq!(
983            get_token_from_headers(headers.clone()).unwrap(),
984            "test-access-token"
985        );
986        assert_eq!(
987            get_token_type_from_headers(headers).unwrap(),
988            "test-token-type"
989        );
990
991        Ok(())
992    }
993
994    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
995    #[parallel]
996    async fn credentials_headers_retryable_error() -> TestResult {
997        let server = Server::run();
998        let scopes = vec!["scope1".to_string()];
999        server.expect(
1000            Expectation::matching(all_of![
1001                request::path(format!("{MDS_DEFAULT_URI}/token")),
1002                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1003            ])
1004            .respond_with(status_code(503)),
1005        );
1006
1007        let mdsc = Builder::default()
1008            .with_endpoint(format!("http://{}", server.addr()))
1009            .with_scopes(scopes)
1010            .without_access_boundary()
1011            .build()?;
1012        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1013        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1014        assert!(original_err.is_transient());
1015        let source = find_source_error::<reqwest::Error>(&err);
1016        assert!(
1017            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1018            "{err:?}"
1019        );
1020
1021        Ok(())
1022    }
1023
1024    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1025    #[parallel]
1026    async fn credentials_headers_nonretryable_error() -> TestResult {
1027        let server = Server::run();
1028        let scopes = vec!["scope1".to_string()];
1029        server.expect(
1030            Expectation::matching(all_of![
1031                request::path(format!("{MDS_DEFAULT_URI}/token")),
1032                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1033            ])
1034            .respond_with(status_code(401)),
1035        );
1036
1037        let mdsc = Builder::default()
1038            .with_endpoint(format!("http://{}", server.addr()))
1039            .with_scopes(scopes)
1040            .without_access_boundary()
1041            .build()?;
1042
1043        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1044        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1045        assert!(!original_err.is_transient());
1046        let source = find_source_error::<reqwest::Error>(&err);
1047        assert!(
1048            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1049            "{err:?}"
1050        );
1051
1052        Ok(())
1053    }
1054
1055    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1056    #[parallel]
1057    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
1058        let server = Server::run();
1059        let scopes = vec!["scope1".to_string()];
1060        server.expect(
1061            Expectation::matching(all_of![
1062                request::path(format!("{MDS_DEFAULT_URI}/token")),
1063                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1064            ])
1065            .respond_with(json_encoded("bad json")),
1066        );
1067
1068        let mdsc = Builder::default()
1069            .with_endpoint(format!("http://{}", server.addr()))
1070            .with_scopes(scopes)
1071            .without_access_boundary()
1072            .build()?;
1073
1074        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1075        assert!(!e.is_transient());
1076
1077        Ok(())
1078    }
1079
1080    #[tokio::test]
1081    #[parallel]
1082    async fn get_default_universe_domain_success() -> TestResult {
1083        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
1084        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
1085        Ok(())
1086    }
1087
1088    #[tokio::test]
1089    #[parallel]
1090    async fn get_mds_signer() -> TestResult {
1091        let server = Server::run();
1092        server.expect(
1093            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
1094                .respond_with(json_encoded(MDSTokenResponse {
1095                    access_token: "test-access-token".to_string(),
1096                    expires_in: None,
1097                    token_type: "Bearer".to_string(),
1098                })),
1099        );
1100        server.expect(
1101            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
1102                .respond_with(status_code(200).body("test-client-email")),
1103        );
1104        server.expect(
1105            Expectation::matching(all_of![
1106                request::method_path(
1107                    "POST",
1108                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
1109                ),
1110                request::headers(contains(("authorization", "Bearer test-access-token"))),
1111            ])
1112            .respond_with(json_encoded(json!({
1113                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
1114            }))),
1115        );
1116
1117        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
1118
1119        let signer = Builder::default()
1120            .with_endpoint(&endpoint)
1121            .maybe_iam_endpoint_override(Some(endpoint))
1122            .without_access_boundary()
1123            .build_signer()?;
1124
1125        let client_email = signer.client_email().await?;
1126        assert_eq!(client_email, "test-client-email");
1127
1128        let signature = signer.sign(b"test").await?;
1129        assert_eq!(signature.as_ref(), b"signed_blob");
1130
1131        Ok(())
1132    }
1133
1134    #[tokio::test]
1135    #[parallel]
1136    #[cfg(google_cloud_unstable_trusted_boundaries)]
1137    async fn e2e_access_boundary() -> TestResult {
1138        use crate::credentials::tests::get_access_boundary_from_headers;
1139
1140        let server = Server::run();
1141        server.expect(
1142            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
1143                .respond_with(json_encoded(MDSTokenResponse {
1144                    access_token: "test-access-token".to_string(),
1145                    expires_in: None,
1146                    token_type: "Bearer".to_string(),
1147                })),
1148        );
1149        server.expect(
1150            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
1151                .respond_with(status_code(200).body("test-client-email")),
1152        );
1153        server.expect(
1154            Expectation::matching(all_of![
1155                request::method_path(
1156                    "GET",
1157                    "/v1/projects/-/serviceAccounts/test-client-email/allowedLocations"
1158                ),
1159                request::headers(contains(("authorization", "Bearer test-access-token"))),
1160            ])
1161            .respond_with(json_encoded(json!({
1162                "locations": ["us-central1", "us-east1"],
1163                "encodedLocations": "0x1234"
1164            }))),
1165        );
1166
1167        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
1168
1169        let creds = Builder::default()
1170            .with_endpoint(&endpoint)
1171            .maybe_iam_endpoint_override(Some(endpoint))
1172            .build_credentials()?;
1173
1174        // let the access boundary background thread update
1175        creds.wait_for_boundary().await;
1176
1177        let headers = creds.headers(Extensions::new()).await?;
1178        let token = get_token_from_headers(headers.clone());
1179        let access_boundary = get_access_boundary_from_headers(headers);
1180        assert!(token.is_some(), "should have some token: {token:?}");
1181        assert_eq!(
1182            access_boundary.as_deref(),
1183            Some("0x1234"),
1184            "should be 0x1234 but found: {access_boundary:?}"
1185        );
1186
1187        Ok(())
1188    }
1189}