google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
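//! The default host name of the metadata service is `metadata.google.internal`.
//! If you would like to use a different hostname, you can set it using the
//! `GCE_METADATA_HOST` environment variable. When this variable is set, it takes
//! precedence over any endpoint configured with [`Builder::with_endpoint`].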
27//!
//! You can use these access tokens to securely authenticate with Google Cloud,
29//! without having to download secrets or other credentials. The types in this
30//! module allow you to retrieve these access tokens, and can be used with
31//! the Google Cloud client libraries for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok::<(), anyhow::Error>(())
46//! # });
47//! ```
48//!
49//! ## Example: Creating credentials with custom retry behavior
50//!
51//! ```
52//! # use google_cloud_auth::credentials::mds::Builder;
53//! # use google_cloud_auth::credentials::Credentials;
54//! # use http::Extensions;
55//! # use std::time::Duration;
56//! # tokio_test::block_on(async {
57//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
58//! use gax::exponential_backoff::ExponentialBackoff;
59//! let backoff = ExponentialBackoff::default();
60//! let credentials: Credentials = Builder::default()
61//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
62//!     .with_backoff_policy(backoff)
63//!     .build()?;
64//! let headers = credentials.headers(Extensions::new()).await?;
65//! println!("Headers: {headers:?}");
66//! # Ok::<(), anyhow::Error>(())
67//! # });
68//! ```
69//!
70//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
71//! [Cloud Run]: https://cloud.google.com/run
72//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
73//! [gce-link]: https://cloud.google.com/products/compute
74//! [gke-link]: https://cloud.google.com/kubernetes-engine
75//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
76
77use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
78use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
/// Value sent in the [`METADATA_FLAVOR`] request header on every token request.
pub(crate) const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header that carries [`METADATA_FLAVOR_VALUE`].
pub(crate) const METADATA_FLAVOR: &str = "metadata-flavor";
/// Root URL of the metadata service, used when neither the
/// `GCE_METADATA_HOST` environment variable nor an explicit builder endpoint is set.
pub(crate) const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path (relative to the endpoint) of the default service account; the token
/// request appends `/token` to it.
pub(crate) const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable naming an alternative metadata host; its value is
/// prefixed with `http://` to form the endpoint.
pub(crate) const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
// Detailed error text used when ADC fell back to MDS credentials without an
// overridden endpoint (see `MDSAccessTokenProvider::error_message`).
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
110
111#[derive(Debug)]
112struct MDSCredentials<T>
113where
114    T: CachedTokenProvider,
115{
116    quota_project_id: Option<String>,
117    token_provider: T,
118}
119
/// Creates [Credentials] instances backed by the [Metadata Service].
///
/// While the Google Cloud client libraries for Rust default to credentials
/// backed by the metadata service, some applications may need to:
/// * Customize the metadata service credentials in some way
/// * Bypass the [Application Default Credentials] lookup and only
///   use the metadata server credentials
/// * Use the credentials directly outside the client libraries
///
/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
#[derive(Debug, Default)]
pub struct Builder {
    /// Endpoint set via `with_endpoint`. The `GCE_METADATA_HOST` environment
    /// variable, when set, takes precedence over it.
    endpoint: Option<String>,
    /// Quota project set via `with_quota_project_id`.
    quota_project_id: Option<String>,
    /// Scopes set via `with_scopes`; `None` lets the metadata service pick the defaults.
    scopes: Option<Vec<String>>,
    /// `true` only for builders created through `from_adc`; it can select a more
    /// detailed error message when the metadata service is unreachable.
    created_by_adc: bool,
    /// Retry, backoff, and throttling configuration applied to token fetches.
    retry_builder: RetryTokenProviderBuilder,
}
139
140impl Builder {
141    /// Sets the endpoint for this credentials.
142    ///
143    /// A trailing slash is significant, so specify the base URL without a trailing  
144    /// slash. If not set, the credentials use `http://metadata.google.internal`.
145    ///
146    /// # Example
147    /// ```
148    /// # use google_cloud_auth::credentials::mds::Builder;
149    /// # tokio_test::block_on(async {
150    /// let credentials = Builder::default()
151    ///     .with_endpoint("https://metadata.google.foobar")
152    ///     .build();
153    /// # });
154    /// ```
155    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
156        self.endpoint = Some(endpoint.into());
157        self
158    }
159
160    /// Set the [quota project] for this credentials.
161    ///
162    /// In some services, you can use a service account in
163    /// one project for authentication and authorization, and charge
164    /// the usage to a different project. This may require that the
165    /// service account has `serviceusage.services.use` permissions on the quota project.
166    ///
167    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
168    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
169        self.quota_project_id = Some(quota_project_id.into());
170        self
171    }
172
173    /// Sets the [scopes] for this credentials.
174    ///
175    /// Metadata server issues tokens based on the requested scopes.
176    /// If no scopes are specified, the credentials defaults to all
177    /// scopes configured for the [default service account] on the instance.
178    ///
179    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
180    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
181    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
182    where
183        I: IntoIterator<Item = S>,
184        S: Into<String>,
185    {
186        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
187        self
188    }
189
190    /// Configure the retry policy for fetching tokens.
191    ///
192    /// The retry policy controls how to handle retries, and sets limits on
193    /// the number of attempts or the total time spent retrying.
194    ///
195    /// ```
196    /// # use google_cloud_auth::credentials::mds::Builder;
197    /// # tokio_test::block_on(async {
198    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
199    /// let credentials = Builder::default()
200    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
201    ///     .build();
202    /// # });
203    /// ```
204    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
205        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
206        self
207    }
208
209    /// Configure the retry backoff policy.
210    ///
211    /// The backoff policy controls how long to wait in between retry attempts.
212    ///
213    /// ```
214    /// # use google_cloud_auth::credentials::mds::Builder;
215    /// # use std::time::Duration;
216    /// # tokio_test::block_on(async {
217    /// use gax::exponential_backoff::ExponentialBackoff;
218    /// let policy = ExponentialBackoff::default();
219    /// let credentials = Builder::default()
220    ///     .with_backoff_policy(policy)
221    ///     .build();
222    /// # });
223    /// ```
224    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
225        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
226        self
227    }
228
229    /// Configure the retry throttler.
230    ///
231    /// Advanced applications may want to configure a retry throttler to
232    /// [Address Cascading Failures] and when [Handling Overload] conditions.
233    /// The authentication library throttles its retry loop, using a policy to
234    /// control the throttling algorithm. Use this method to fine tune or
235    /// customize the default retry throttler.
236    ///
237    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
238    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
239    ///
240    /// ```
241    /// # use google_cloud_auth::credentials::mds::Builder;
242    /// # tokio_test::block_on(async {
243    /// use gax::retry_throttler::AdaptiveThrottler;
244    /// let credentials = Builder::default()
245    ///     .with_retry_throttler(AdaptiveThrottler::default())
246    ///     .build();
247    /// # });
248    /// ```
249    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
250        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
251        self
252    }
253
254    // This method is used to build mds credentials from ADC
255    pub(crate) fn from_adc() -> Self {
256        Self {
257            created_by_adc: true,
258            ..Default::default()
259        }
260    }
261
262    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
263        let final_endpoint: String;
264        let endpoint_overridden: bool;
265
266        // Determine the endpoint and whether it was overridden
267        if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
268            // Check GCE_METADATA_HOST environment variable first
269            final_endpoint = format!("http://{host_from_env}");
270            endpoint_overridden = true;
271        } else if let Some(builder_endpoint) = self.endpoint {
272            // Else, check if an endpoint was provided to the mds::Builder
273            final_endpoint = builder_endpoint;
274            endpoint_overridden = true;
275        } else {
276            // Else, use the default metadata root
277            final_endpoint = METADATA_ROOT.to_string();
278            endpoint_overridden = false;
279        };
280
281        let tp = MDSAccessTokenProvider::builder()
282            .endpoint(final_endpoint)
283            .maybe_scopes(self.scopes)
284            .endpoint_overridden(endpoint_overridden)
285            .created_by_adc(self.created_by_adc)
286            .build();
287        self.retry_builder.build(tp)
288    }
289
290    /// Returns a [Credentials] instance with the configured settings.
291    pub fn build(self) -> BuildResult<Credentials> {
292        Ok(self.build_access_token_credentials()?.into())
293    }
294
295    /// Returns an [AccessTokenCredentials] instance with the configured settings.
296    ///
297    /// # Example
298    ///
299    /// ```
300    /// # use google_cloud_auth::credentials::mds::Builder;
301    /// # use google_cloud_auth::credentials::{AccessTokenCredentials, AccessTokenCredentialsProvider};
302    /// # tokio_test::block_on(async {
303    /// let credentials: AccessTokenCredentials = Builder::default()
304    ///     .with_quota_project_id("my-quota-project")
305    ///     .build_access_token_credentials()?;
306    /// let access_token = credentials.access_token().await?;
307    /// println!("Token: {}", access_token.token);
308    /// # Ok::<(), anyhow::Error>(())
309    /// # });
310    /// ```
311    pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
312        let mdsc = MDSCredentials {
313            quota_project_id: self.quota_project_id.clone(),
314            token_provider: TokenCache::new(self.build_token_provider()),
315        };
316        Ok(AccessTokenCredentials {
317            inner: Arc::new(mdsc),
318        })
319    }
320}
321
322#[async_trait::async_trait]
323impl<T> CredentialsProvider for MDSCredentials<T>
324where
325    T: CachedTokenProvider,
326{
327    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
328        let cached_token = self.token_provider.token(extensions).await?;
329        build_cacheable_headers(&cached_token, &self.quota_project_id)
330    }
331}
332
333#[async_trait::async_trait]
334impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
335where
336    T: CachedTokenProvider,
337{
338    async fn access_token(&self) -> Result<AccessToken> {
339        let token = self.token_provider.token(Extensions::new()).await?;
340        token.into()
341    }
342}
343
/// JSON body returned by the metadata service token endpoint.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    /// The access token itself.
    access_token: String,
    /// Token lifetime in seconds (converted with `Duration::from_secs`); may be absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    /// Token type, used as the prefix of the authorization header value.
    token_type: String,
}
351
/// Token provider that fetches access tokens from the metadata service.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    /// Scopes to request; `None` omits the `scopes` query parameter.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    /// Base URL of the metadata service (no trailing slash expected).
    #[builder(into)]
    endpoint: String,
    /// Whether the endpoint came from the environment or the builder rather than the default.
    endpoint_overridden: bool,
    /// Whether these credentials were created by the ADC lookup.
    created_by_adc: bool,
}
361
362impl MDSAccessTokenProvider {
363    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
364    // environment variable is not set, we default to MDS credentials without checking if the code is really
365    // running in an environment with MDS. To help users who got to this state because of lack of credentials
366    // setup on their machines, we provide a detailed error message to them talking about local setup and other
367    // auth mechanisms available to them.
368    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
369    // error message because they deliberately wanted to use an MDS.
370    fn error_message(&self) -> &str {
371        if self.use_adc_message() {
372            MDS_NOT_FOUND_ERROR
373        } else {
374            "failed to fetch token"
375        }
376    }
377
378    fn use_adc_message(&self) -> bool {
379        self.created_by_adc && !self.endpoint_overridden
380    }
381}
382
383#[async_trait]
384impl TokenProvider for MDSAccessTokenProvider {
385    async fn token(&self) -> Result<Token> {
386        let client = Client::new();
387        let request = client
388            .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
389            .header(
390                METADATA_FLAVOR,
391                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
392            );
393        // Use the `scopes` option if set, otherwise let the MDS use the default
394        // scopes.
395        let scopes = self.scopes.as_ref().map(|v| v.join(","));
396        let request = scopes
397            .into_iter()
398            .fold(request, |r, s| r.query(&[("scopes", s)]));
399
400        // If the connection to MDS was not successful, it is useful to retry when really
401        // running on MDS environments and not useful if there is no MDS. We will mark the error
402        // as retryable and let the retry policy determine whether to retry or not. Whenever we
403        // define a default retry policy, we can skip retrying this case.
404        let response = request
405            .send()
406            .await
407            .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
408        // Process the response
409        if !response.status().is_success() {
410            let err = crate::errors::from_http_response(response, self.error_message()).await;
411            return Err(err);
412        }
413        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
414            // Decoding errors are not transient. Typically they indicate a badly
415            // configured MDS endpoint, or DNS redirecting the request to a random
416            // server, e.g., ISPs that redirect unknown services to HTTP.
417            CredentialsError::from_source(!e.is_decode(), e)
418        })?;
419        let token = Token {
420            token: response.access_token,
421            token_type: response.token_type,
422            expires_at: response
423                .expires_in
424                .map(|d| Instant::now() + Duration::from_secs(d)),
425            metadata: None,
426        };
427        Ok(token)
428    }
429}
430
431#[cfg(test)]
432mod tests {
433    use super::*;
434    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
435    use crate::credentials::QUOTA_PROJECT_KEY;
436    use crate::credentials::tests::{
437        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
438        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
439        get_token_type_from_headers,
440    };
441    use crate::errors;
442    use crate::errors::CredentialsError;
443    use crate::token::tests::MockTokenProvider;
444    use http::HeaderValue;
445    use http::header::AUTHORIZATION;
446    use httptest::cycle;
447    use httptest::matchers::{all_of, contains, request, url_decoded};
448    use httptest::responders::{json_encoded, status_code};
449    use httptest::{Expectation, Server};
450    use reqwest::StatusCode;
451    use scoped_env::ScopedEnv;
452    use serial_test::{parallel, serial};
453    use std::error::Error;
454    use test_case::test_case;
455    use url::Url;
456
457    type TestResult = anyhow::Result<()>;
458
459    #[tokio::test]
460    #[parallel]
461    async fn test_mds_retries_on_transient_failures() -> TestResult {
462        let mut server = Server::run();
463        server.expect(
464            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
465                .times(3)
466                .respond_with(status_code(503)),
467        );
468
469        let provider = Builder::default()
470            .with_endpoint(format!("http://{}", server.addr()))
471            .with_retry_policy(get_mock_auth_retry_policy(3))
472            .with_backoff_policy(get_mock_backoff_policy())
473            .with_retry_throttler(get_mock_retry_throttler())
474            .build_token_provider();
475
476        let err = provider.token().await.unwrap_err();
477        assert!(!err.is_transient());
478        server.verify_and_clear();
479        Ok(())
480    }
481
482    #[tokio::test]
483    #[parallel]
484    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
485        let mut server = Server::run();
486        server.expect(
487            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
488                .times(1)
489                .respond_with(status_code(401)),
490        );
491
492        let provider = Builder::default()
493            .with_endpoint(format!("http://{}", server.addr()))
494            .with_retry_policy(get_mock_auth_retry_policy(1))
495            .with_backoff_policy(get_mock_backoff_policy())
496            .with_retry_throttler(get_mock_retry_throttler())
497            .build_token_provider();
498
499        let err = provider.token().await.unwrap_err();
500        assert!(!err.is_transient());
501        server.verify_and_clear();
502        Ok(())
503    }
504
505    #[tokio::test]
506    #[parallel]
507    async fn test_mds_retries_for_success() -> TestResult {
508        let mut server = Server::run();
509        let response = MDSTokenResponse {
510            access_token: "test-access-token".to_string(),
511            expires_in: Some(3600),
512            token_type: "test-token-type".to_string(),
513        };
514
515        server.expect(
516            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
517                .times(3)
518                .respond_with(cycle![
519                    status_code(503).body("try-again"),
520                    status_code(503).body("try-again"),
521                    status_code(200)
522                        .append_header("Content-Type", "application/json")
523                        .body(serde_json::to_string(&response).unwrap()),
524                ]),
525        );
526
527        let provider = Builder::default()
528            .with_endpoint(format!("http://{}", server.addr()))
529            .with_retry_policy(get_mock_auth_retry_policy(3))
530            .with_backoff_policy(get_mock_backoff_policy())
531            .with_retry_throttler(get_mock_retry_throttler())
532            .build_token_provider();
533
534        let token = provider.token().await?;
535        assert_eq!(token.token, "test-access-token");
536
537        server.verify_and_clear();
538        Ok(())
539    }
540
541    #[test]
542    fn validate_default_endpoint_urls() {
543        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
544        assert!(default_endpoint_address.is_ok());
545
546        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
547        assert!(token_endpoint_address.is_ok());
548    }
549
550    #[tokio::test]
551    async fn headers_success() -> TestResult {
552        let token = Token {
553            token: "test-token".to_string(),
554            token_type: "Bearer".to_string(),
555            expires_at: None,
556            metadata: None,
557        };
558
559        let mut mock = MockTokenProvider::new();
560        mock.expect_token().times(1).return_once(|| Ok(token));
561
562        let mdsc = MDSCredentials {
563            quota_project_id: None,
564            token_provider: TokenCache::new(mock),
565        };
566
567        let mut extensions = Extensions::new();
568        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
569        let (headers, entity_tag) = match cached_headers {
570            CacheableResource::New { entity_tag, data } => (data, entity_tag),
571            CacheableResource::NotModified => unreachable!("expecting new headers"),
572        };
573        let token = headers.get(AUTHORIZATION).unwrap();
574        assert_eq!(headers.len(), 1, "{headers:?}");
575        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
576        assert!(token.is_sensitive());
577
578        extensions.insert(entity_tag);
579
580        let cached_headers = mdsc.headers(extensions).await?;
581
582        match cached_headers {
583            CacheableResource::New { .. } => unreachable!("expecting new headers"),
584            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
585        };
586        Ok(())
587    }
588
589    #[tokio::test]
590    async fn access_token_success() -> TestResult {
591        let server = Server::run();
592        let response = MDSTokenResponse {
593            access_token: "test-access-token".to_string(),
594            expires_in: Some(3600),
595            token_type: "Bearer".to_string(),
596        };
597        server.expect(
598            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
599                .respond_with(json_encoded(response)),
600        );
601
602        let creds = Builder::default()
603            .with_endpoint(format!("http://{}", server.addr()))
604            .build_access_token_credentials()
605            .unwrap();
606
607        let access_token = creds.access_token().await.unwrap();
608        assert_eq!(access_token.token, "test-access-token");
609
610        Ok(())
611    }
612
613    #[tokio::test]
614    async fn headers_failure() {
615        let mut mock = MockTokenProvider::new();
616        mock.expect_token()
617            .times(1)
618            .return_once(|| Err(errors::non_retryable_from_str("fail")));
619
620        let mdsc = MDSCredentials {
621            quota_project_id: None,
622            token_provider: TokenCache::new(mock),
623        };
624        assert!(mdsc.headers(Extensions::new()).await.is_err());
625    }
626
627    #[test]
628    fn error_message_with_adc() {
629        let provider = MDSAccessTokenProvider::builder()
630            .endpoint("http://127.0.0.1")
631            .created_by_adc(true)
632            .endpoint_overridden(false)
633            .build();
634
635        let want = MDS_NOT_FOUND_ERROR;
636        let got = provider.error_message();
637        assert!(got.contains(want), "{got}, {provider:?}");
638    }
639
640    #[test_case(false, false)]
641    #[test_case(false, true)]
642    #[test_case(true, true)]
643    fn error_message_without_adc(adc: bool, overridden: bool) {
644        let provider = MDSAccessTokenProvider::builder()
645            .endpoint("http://127.0.0.1")
646            .created_by_adc(adc)
647            .endpoint_overridden(overridden)
648            .build();
649
650        let not_want = MDS_NOT_FOUND_ERROR;
651        let got = provider.error_message();
652        assert!(!got.contains(not_want), "{got}, {provider:?}");
653    }
654
655    #[tokio::test]
656    #[serial]
657    async fn adc_no_mds() -> TestResult {
658        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
659            // The environment has an MDS, skip the test.
660            return Ok(());
661        };
662
663        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
664        assert!(
665            original_err.to_string().contains("application-default"),
666            "display={err}, debug={err:?}"
667        );
668
669        Ok(())
670    }
671
672    #[tokio::test]
673    #[serial]
674    async fn adc_overridden_mds() -> TestResult {
675        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
676
677        let err = Builder::from_adc()
678            .build_token_provider()
679            .token()
680            .await
681            .unwrap_err();
682
683        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
684
685        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
686        assert!(original_err.is_transient());
687        assert!(
688            !original_err.to_string().contains("application-default"),
689            "display={err}, debug={err:?}"
690        );
691        let source = find_source_error::<reqwest::Error>(&err);
692        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
693
694        Ok(())
695    }
696
697    #[tokio::test]
698    #[serial]
699    async fn builder_no_mds() -> TestResult {
700        let Err(e) = Builder::default().build_token_provider().token().await else {
701            // The environment has an MDS, skip the test.
702            return Ok(());
703        };
704
705        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
706        assert!(
707            !format!("{:?}", original_err.source()).contains("application-default"),
708            "{e:?}"
709        );
710
711        Ok(())
712    }
713
714    #[tokio::test]
715    #[serial]
716    async fn test_gce_metadata_host_env_var() -> TestResult {
717        let server = Server::run();
718        let scopes = ["scope1", "scope2"];
719        let response = MDSTokenResponse {
720            access_token: "test-access-token".to_string(),
721            expires_in: Some(3600),
722            token_type: "test-token-type".to_string(),
723        };
724        server.expect(
725            Expectation::matching(all_of![
726                request::path(format!("{MDS_DEFAULT_URI}/token")),
727                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
728            ])
729            .respond_with(json_encoded(response)),
730        );
731
732        let addr = server.addr().to_string();
733        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
734        let mdsc = Builder::default()
735            .with_scopes(["scope1", "scope2"])
736            .build()
737            .unwrap();
738        let headers = mdsc.headers(Extensions::new()).await.unwrap();
739        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
740
741        assert_eq!(
742            get_token_from_headers(headers).unwrap(),
743            "test-access-token"
744        );
745        Ok(())
746    }
747
748    #[tokio::test]
749    #[parallel]
750    async fn headers_success_with_quota_project() -> TestResult {
751        let server = Server::run();
752        let scopes = ["scope1", "scope2"];
753        let response = MDSTokenResponse {
754            access_token: "test-access-token".to_string(),
755            expires_in: Some(3600),
756            token_type: "test-token-type".to_string(),
757        };
758        server.expect(
759            Expectation::matching(all_of![
760                request::path(format!("{MDS_DEFAULT_URI}/token")),
761                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
762            ])
763            .respond_with(json_encoded(response)),
764        );
765
766        let mdsc = Builder::default()
767            .with_scopes(["scope1", "scope2"])
768            .with_endpoint(format!("http://{}", server.addr()))
769            .with_quota_project_id("test-project")
770            .build()?;
771
772        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
773        let token = headers.get(AUTHORIZATION).unwrap();
774        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
775
776        assert_eq!(headers.len(), 2, "{headers:?}");
777        assert_eq!(
778            token,
779            HeaderValue::from_static("test-token-type test-access-token")
780        );
781        assert!(token.is_sensitive());
782        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
783        assert!(!quota_project.is_sensitive());
784
785        Ok(())
786    }
787
788    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
789    #[parallel]
790    async fn token_caching() -> TestResult {
791        let mut server = Server::run();
792        let scopes = vec!["scope1".to_string()];
793        let response = MDSTokenResponse {
794            access_token: "test-access-token".to_string(),
795            expires_in: Some(3600),
796            token_type: "test-token-type".to_string(),
797        };
798        server.expect(
799            Expectation::matching(all_of![
800                request::path(format!("{MDS_DEFAULT_URI}/token")),
801                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
802            ])
803            .times(1)
804            .respond_with(json_encoded(response)),
805        );
806
807        let mdsc = Builder::default()
808            .with_scopes(scopes)
809            .with_endpoint(format!("http://{}", server.addr()))
810            .build()?;
811        let headers = mdsc.headers(Extensions::new()).await?;
812        assert_eq!(
813            get_token_from_headers(headers).unwrap(),
814            "test-access-token"
815        );
816        let headers = mdsc.headers(Extensions::new()).await?;
817        assert_eq!(
818            get_token_from_headers(headers).unwrap(),
819            "test-access-token"
820        );
821
822        // validate that the inner token provider is called only once
823        server.verify_and_clear();
824
825        Ok(())
826    }
827
828    #[tokio::test(start_paused = true)]
829    #[parallel]
830    async fn token_provider_full() -> TestResult {
831        let server = Server::run();
832        let scopes = vec!["scope1".to_string()];
833        let response = MDSTokenResponse {
834            access_token: "test-access-token".to_string(),
835            expires_in: Some(3600),
836            token_type: "test-token-type".to_string(),
837        };
838        server.expect(
839            Expectation::matching(all_of![
840                request::path(format!("{MDS_DEFAULT_URI}/token")),
841                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
842            ])
843            .respond_with(json_encoded(response)),
844        );
845
846        let token = Builder::default()
847            .with_endpoint(format!("http://{}", server.addr()))
848            .with_scopes(scopes)
849            .build_token_provider()
850            .token()
851            .await?;
852
853        let now = tokio::time::Instant::now();
854        assert_eq!(token.token, "test-access-token");
855        assert_eq!(token.token_type, "test-token-type");
856        assert!(
857            token
858                .expires_at
859                .is_some_and(|d| d >= now + Duration::from_secs(3600))
860        );
861
862        Ok(())
863    }
864
865    #[tokio::test(start_paused = true)]
866    #[parallel]
867    async fn token_provider_full_no_scopes() -> TestResult {
868        let server = Server::run();
869        let response = MDSTokenResponse {
870            access_token: "test-access-token".to_string(),
871            expires_in: Some(3600),
872            token_type: "test-token-type".to_string(),
873        };
874        server.expect(
875            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
876                .respond_with(json_encoded(response)),
877        );
878
879        let token = Builder::default()
880            .with_endpoint(format!("http://{}", server.addr()))
881            .build_token_provider()
882            .token()
883            .await?;
884
885        let now = Instant::now();
886        assert_eq!(token.token, "test-access-token");
887        assert_eq!(token.token_type, "test-token-type");
888        assert!(
889            token
890                .expires_at
891                .is_some_and(|d| d == now + Duration::from_secs(3600))
892        );
893
894        Ok(())
895    }
896
897    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
898    #[parallel]
899    async fn credential_provider_full() -> TestResult {
900        let server = Server::run();
901        let scopes = vec!["scope1".to_string()];
902        let response = MDSTokenResponse {
903            access_token: "test-access-token".to_string(),
904            expires_in: None,
905            token_type: "test-token-type".to_string(),
906        };
907        server.expect(
908            Expectation::matching(all_of![
909                request::path(format!("{MDS_DEFAULT_URI}/token")),
910                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
911            ])
912            .respond_with(json_encoded(response)),
913        );
914
915        let mdsc = Builder::default()
916            .with_endpoint(format!("http://{}", server.addr()))
917            .with_scopes(scopes)
918            .build()?;
919        let headers = mdsc.headers(Extensions::new()).await?;
920        assert_eq!(
921            get_token_from_headers(headers.clone()).unwrap(),
922            "test-access-token"
923        );
924        assert_eq!(
925            get_token_type_from_headers(headers).unwrap(),
926            "test-token-type"
927        );
928
929        Ok(())
930    }
931
932    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
933    #[parallel]
934    async fn credentials_headers_retryable_error() -> TestResult {
935        let server = Server::run();
936        let scopes = vec!["scope1".to_string()];
937        server.expect(
938            Expectation::matching(all_of![
939                request::path(format!("{MDS_DEFAULT_URI}/token")),
940                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
941            ])
942            .respond_with(status_code(503)),
943        );
944
945        let mdsc = Builder::default()
946            .with_endpoint(format!("http://{}", server.addr()))
947            .with_scopes(scopes)
948            .build()?;
949        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
950        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
951        assert!(original_err.is_transient());
952        let source = find_source_error::<reqwest::Error>(&err);
953        assert!(
954            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
955            "{err:?}"
956        );
957
958        Ok(())
959    }
960
961    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
962    #[parallel]
963    async fn credentials_headers_nonretryable_error() -> TestResult {
964        let server = Server::run();
965        let scopes = vec!["scope1".to_string()];
966        server.expect(
967            Expectation::matching(all_of![
968                request::path(format!("{MDS_DEFAULT_URI}/token")),
969                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
970            ])
971            .respond_with(status_code(401)),
972        );
973
974        let mdsc = Builder::default()
975            .with_endpoint(format!("http://{}", server.addr()))
976            .with_scopes(scopes)
977            .build()?;
978
979        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
980        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
981        assert!(!original_err.is_transient());
982        let source = find_source_error::<reqwest::Error>(&err);
983        assert!(
984            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
985            "{err:?}"
986        );
987
988        Ok(())
989    }
990
991    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
992    #[parallel]
993    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
994        let server = Server::run();
995        let scopes = vec!["scope1".to_string()];
996        server.expect(
997            Expectation::matching(all_of![
998                request::path(format!("{MDS_DEFAULT_URI}/token")),
999                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1000            ])
1001            .respond_with(json_encoded("bad json")),
1002        );
1003
1004        let mdsc = Builder::default()
1005            .with_endpoint(format!("http://{}", server.addr()))
1006            .with_scopes(scopes)
1007            .build()?;
1008
1009        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1010        assert!(!e.is_transient());
1011
1012        Ok(())
1013    }
1014
1015    #[tokio::test]
1016    async fn get_default_universe_domain_success() -> TestResult {
1017        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
1018        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
1019        Ok(())
1020    }
1021}